<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="4.2.2">Jekyll</generator><link href="https://nadavb.com/feed.xml" rel="self" type="application/atom+xml" /><link href="https://nadavb.com/" rel="alternate" type="text/html" /><updated>2025-09-24T20:12:47+03:00</updated><id>https://nadavb.com/feed.xml</id><title type="html">Nadav Benedek</title><entry><title type="html">Label Shift and Domain Adaptation in Machine Learning</title><link href="https://nadavb.com/Label-Shift-and-Domain-Adaptation-in-Machine-Learning/" rel="alternate" type="text/html" title="Label Shift and Domain Adaptation in Machine Learning" /><published>2024-10-05T11:27:00+03:00</published><updated>2024-10-05T11:27:00+03:00</updated><id>https://nadavb.com/Label%20Shift%20and%20Domain%20Adaptation%20in%20Machine%20Learning</id><content type="html" xml:base="https://nadavb.com/Label-Shift-and-Domain-Adaptation-in-Machine-Learning/"><![CDATA[<p>TL;DR: If you want the best <em>accuracy</em> on the target domain, you have to match the class frequency in the training set. You cannot affect the 
ROC AUC nor the PR AUC, but you can affect the accuracy. If you don’t know the target distribution at training time, you can estimate it at test time using only the features, compute a correction weight for each training class, and retrain the model so that it matches the newly detected distribution (see <a href="https://arxiv.org/abs/1802.03916" target="_blank">this</a> paper). Sometimes the future target distribution is known in advance. For example, if you build a dice image classifier, you can expect the rolled dice to be uniformly distributed, so you can balance the classes at training time.</p>
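<p>The weight-correction idea can be sketched as follows. Under label shift, the target label distribution $q(y)$ is recoverable from quantities that need no target labels: the model’s confusion matrix on held-out source data, and the distribution of its predictions on target features. This is the essence of the linked paper; the function names below are illustrative, not taken from its code:</p>

```python
import numpy as np

def estimate_target_prior(confusion, target_pred_dist):
    """Estimate the target label distribution q(y) under label shift.

    confusion[i, j] = p(model predicts i | true class j), estimated on
    held-out *source* data; target_pred_dist[i] = fraction of *target*
    samples the model predicts as class i. Under label shift the
    prediction distribution satisfies confusion @ q = target_pred_dist,
    so we solve that linear system for q.
    """
    q = np.linalg.solve(confusion, target_pred_dist)
    q = np.clip(q, 0.0, None)   # estimation noise can push entries below 0
    return q / q.sum()

def class_reweights(source_prior, target_prior):
    # Importance weight q(y)/p(y) to apply to each training class.
    return np.asarray(target_prior) / np.asarray(source_prior)
```

The resulting per-class weights can then be used (e.g., as sample weights) when retraining on the source data.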

<h4 id="what-is-a-label-shift">What is a Label Shift?</h4>

<p>Label shift occurs when the class/label/target distribution at deployment time (test time) differs from the one at training time. For example, you train a cat/dog classifier using 1000 images of dogs and 4000 images of cats, because that’s the distribution of pets people have at home in France. However, when you deploy the model in Germany, where 50% of the people have cats and 50% have dogs, that’s a label shift.
In label shift, the target distribution is different, but the manifestation of targets as features remains the same. That means dogs in Germany look the same as dogs in France: you only have more dogs, but they are the same kind of dogs. If dogs in Germany looked different from dogs in France, that would be a different phenomenon, not label shift.
More formally, if the source distribution is $p$ and the target distribution is $q$, the class-conditional feature distribution remains the same, $p(\boldsymbol{x}|y)=q(\boldsymbol{x}|y)$, while the label distribution changes, $p(y) \neq q(y)$.</p>
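<p>As a quick sanity check of this condition, here is a small simulation (with made-up numbers): two populations with different class priors but identical class-conditional feature distributions.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_population(n, p_class0):
    # Label shift: the prior p(y) differs between populations,
    # but p(x|y) is identical (same per-class Gaussians).
    y = (rng.random(n) >= p_class0).astype(int)       # label 0 or 1
    x = np.where(y == 0,
                 rng.normal(0.0, 1.0, n),             # class-0 features
                 rng.normal(3.0, 1.0, n))             # class-1 features
    return x, y

x_a, y_a = sample_population(100_000, 0.7)   # source-like prior
x_b, y_b = sample_population(100_000, 0.5)   # target-like prior
```

The class frequencies differ between the two samples, yet the per-class feature means (and spreads) agree, which is exactly the label-shift setting.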

<h4 id="live-example">Live example</h4>

<p>Let’s see what happens under label shift: we train a classifier on a source distribution, then check its performance both when the distribution stays the same and when the distribution of true labels changes.</p>

<p>Let’s say a basketball player’s average height is 180 cm, with a stdev of 10 cm, and a football player’s average height is 170 cm, also with a stdev of 10 cm. Assume this is true globally (in every country).</p>

<p>We collect a dataset of players in France, where our dataset contains 70% basketball players and 30% football players. We will train the model and check its performance in France. Then, we will check its performance when we deploy the model in Germany, where we have 50% basketball players and 50% football players.</p>

<p>Let’s write some code. First, a few helper functions.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># @title Click to Expand/Collapse
</span><span class="n">basketball_height</span><span class="p">,</span> <span class="n">basket_std</span> <span class="o">=</span> <span class="mi">180</span><span class="p">,</span> <span class="mi">10</span>
<span class="n">football_height</span><span class="p">,</span> <span class="n">football_std</span> <span class="o">=</span> <span class="mi">170</span><span class="p">,</span> <span class="mi">10</span>

<span class="n">dataset_length</span> <span class="o">=</span> <span class="mi">20000</span>
<span class="n">test_set_portion</span> <span class="o">=</span> <span class="mf">0.4</span>

<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">from</span> <span class="nn">sklearn.linear_model</span> <span class="kn">import</span> <span class="n">LogisticRegression</span>
<span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span>
<span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">accuracy_score</span><span class="p">,</span> <span class="n">classification_report</span><span class="p">,</span> <span class="n">roc_auc_score</span><span class="p">,</span> <span class="n">average_precision_score</span><span class="p">,</span> <span class="n">confusion_matrix</span>

<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">44</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">generate_dataset</span><span class="p">(</span><span class="n">dataset_length</span><span class="p">,</span> <span class="n">probability_of_0</span><span class="p">):</span>
  <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">choice</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">size</span><span class="o">=</span><span class="n">dataset_length</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="p">[</span><span class="n">probability_of_0</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">probability_of_0</span><span class="p">])</span>
  <span class="n">num_ones</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">y</span><span class="p">)</span>
  <span class="c1"># print(f"#[0] == {len(y)-num_ones}, #[1] == {num_ones}")
</span>  <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">empty</span><span class="p">(</span><span class="n">dataset_length</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">float</span><span class="p">)</span>
  <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">dataset_length</span><span class="p">):</span>
    <span class="k">if</span> <span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">==</span><span class="mi">0</span><span class="p">:</span>
      <span class="n">x</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">basketball_height</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">basket_std</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    <span class="k">else</span><span class="p">:</span>
      <span class="n">x</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="n">football_height</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="n">football_std</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
  <span class="n">X</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># make X a matrix and not a vector
</span>  <span class="k">return</span> <span class="n">X</span><span class="p">,</span><span class="n">y</span>


<span class="k">def</span> <span class="nf">print_metrics</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_proba</span><span class="p">,</span> <span class="n">threshold</span><span class="p">):</span>

    <span class="n">y_pred</span> <span class="o">=</span> <span class="p">(</span><span class="n">y_proba</span> <span class="o">&gt;=</span> <span class="n">threshold</span><span class="p">).</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span>

    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">Confusion Matrix - each row true class (percentage), at threshold of </span><span class="si">{</span><span class="n">threshold</span><span class="si">}</span><span class="s">:"</span><span class="p">)</span>
    <span class="n">cm</span> <span class="o">=</span> <span class="n">confusion_matrix</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span>
    <span class="n">cm_percentage</span> <span class="o">=</span> <span class="p">(</span><span class="n">cm</span> <span class="o">/</span> <span class="n">cm</span><span class="p">.</span><span class="nb">sum</span><span class="p">())</span> <span class="o">*</span> <span class="mi">100</span>  <span class="c1"># Normalize by the total number of samples
</span>    <span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">round</span><span class="p">(</span><span class="n">cm_percentage</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span>  <span class="c1"># Print with two decimal places
</span>
    <span class="c1"># print(f"Model accuracy: {accuracy_score(y_test, y_pred):.2f}")
</span>
    <span class="c1"># Print classification report
</span>    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">Classification Report at treshold </span><span class="si">{</span><span class="n">threshold</span><span class="si">}</span><span class="s">:"</span><span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="n">classification_report</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">))</span>

    <span class="c1"># Calculate and print ROC AUC
</span>    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">y_test</span><span class="p">))</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span>  <span class="c1"># Ensure it's binary classification
</span>        <span class="n">roc_auc</span> <span class="o">=</span> <span class="n">roc_auc_score</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_proba</span><span class="p">)</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">ROC AUC: </span><span class="si">{</span><span class="n">roc_auc</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

        <span class="c1"># Calculate and print PR AUC
</span>        <span class="n">pr_auc</span> <span class="o">=</span> <span class="n">average_precision_score</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_proba</span><span class="p">)</span>
        <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"PR AUC (Precision-Recall AUC): </span><span class="si">{</span><span class="n">pr_auc</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">print</span><span class="p">(</span><span class="s">"</span><span class="se">\n</span><span class="s">ROC AUC: Not applicable for multi-class classification"</span><span class="p">)</span>

<span class="k">def</span> <span class="nf">test_on_country_with_this_class_0_prob</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">probability_of_0</span><span class="p">,</span> <span class="n">threshold</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">False</span><span class="p">):</span>
  <span class="n">X</span><span class="p">,</span><span class="n">y</span> <span class="o">=</span> <span class="n">generate_dataset</span><span class="p">(</span><span class="n">dataset_length</span><span class="p">,</span> <span class="n">probability_of_0</span><span class="p">)</span>
  <span class="n">X_train</span><span class="p">,</span> <span class="n">X_test</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_test</span>  <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span><span class="n">X</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="n">test_set_portion</span><span class="p">)</span>
  <span class="k">if</span> <span class="n">train</span><span class="p">:</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">()</span>
    <span class="n">model</span><span class="p">.</span><span class="n">fit</span><span class="p">(</span><span class="n">X_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">)</span>

  <span class="n">y_proba</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">predict_proba</span><span class="p">(</span><span class="n">X_test</span><span class="p">)[:,</span> <span class="mi">1</span><span class="p">]</span>

  <span class="n">print_metrics</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">y_proba</span><span class="p">,</span> <span class="n">threshold</span><span class="p">)</span>
  <span class="k">if</span> <span class="n">train</span><span class="p">:</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="k">def</span> <span class="nf">train_test_on_source_then_test_on_two_more_countries</span><span class="p">(</span><span class="n">source_prob</span><span class="p">,</span> <span class="n">target_probabilities</span><span class="p">,</span> <span class="n">threshold</span><span class="p">):</span>
  <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">**** country we train+test at class 0 ratio of </span><span class="si">{</span><span class="n">source_prob</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
  <span class="n">model</span> <span class="o">=</span> <span class="n">test_on_country_with_this_class_0_prob</span><span class="p">(</span><span class="n">model</span><span class="o">=</span><span class="bp">None</span><span class="p">,</span> <span class="n">probability_of_0</span> <span class="o">=</span> <span class="n">source_prob</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="n">threshold</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
  <span class="k">for</span> <span class="n">target_prob</span> <span class="ow">in</span> <span class="n">target_probabilities</span><span class="p">:</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="se">\n</span><span class="s">**** Test in a country with class 0 ratio of </span><span class="si">{</span><span class="n">target_prob</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
    <span class="n">test_on_country_with_this_class_0_prob</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">target_prob</span><span class="p">,</span> <span class="n">threshold</span><span class="p">)</span>
</code></pre></div></div>

<p>Now train + test on the SOURCE country, where class 0 has a population frequency of 0.7, and also test the model on countries with frequencies 0.9 and 0.5:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">train_test_on_source_then_test_on_two_more_countries</span><span class="p">(</span><span class="n">source_prob</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="n">target_probabilities</span><span class="o">=</span><span class="p">[</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
</code></pre></div></div>

<p>And the results:</p>

<pre><code class="language-python2">**** country we train+test at class 0 ratio of 0.7

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[63.3   6.42]
 [19.31 10.96]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.77      0.91      0.83      5578
           1       0.63      0.36      0.46      2422

    accuracy                           0.74      8000
   macro avg       0.70      0.63      0.65      8000
weighted avg       0.73      0.74      0.72      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.58

**** Test in a country with class 0 ratio of 0.9

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[81.7   8.15]
 [ 6.68  3.48]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.92      0.91      0.92      7188
           1       0.30      0.34      0.32       812

    accuracy                           0.85      8000
   macro avg       0.61      0.63      0.62      8000
weighted avg       0.86      0.85      0.86      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.30

**** Test in a country with class 0 ratio of 0.5

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[45.5   4.75]
 [31.15 18.6 ]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.59      0.91      0.72      4020
           1       0.80      0.37      0.51      3980

    accuracy                           0.64      8000
   macro avg       0.70      0.64      0.61      8000
weighted avg       0.69      0.64      0.61      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.75
</code></pre>

<p>In the country with a 0.9 ratio of basketball players, the accuracy <em>increased</em> (74%-&gt;85%), the ROC AUC remained the same, but the PR AUC decreased.</p>

<p>In the target country with an equal ratio of players (0.5), the accuracy is <em>lower</em> than in the source country (74%-&gt;64%), but the ROC AUC is the same. The ROC AUC remains the same no matter which threshold we choose and no matter which class-0 probability the target has. Also, the PR AUC increased.</p>

<p>Why does accuracy change under label shift? When we train a classifier, it learns not only the relation between the features and the label, but also the class proportions of the training distribution; its decision rule implicitly encodes the training prior.</p>
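<p>For a probabilistic classifier there is in fact a closed-form correction: by Bayes’ rule, under label shift the target posterior is proportional to the source posterior rescaled by the prior ratio, $q(y|\boldsymbol{x}) \propto p(y|\boldsymbol{x}) \, q(y)/p(y)$. A minimal sketch (the function name is mine):</p>

```python
import numpy as np

def adjust_posterior(proba, source_prior, target_prior):
    """Rescale class probabilities p(y|x) for a shifted label prior q(y).

    proba: (n, k) array of source-model probabilities.
    Returns q(y|x) proportional to p(y|x) * q(y) / p(y), renormalized per row.
    """
    w = np.asarray(target_prior, dtype=float) / np.asarray(source_prior, dtype=float)
    adjusted = proba * w                 # rescale each class column
    return adjusted / adjusted.sum(axis=1, keepdims=True)
```

For a well-calibrated model, thresholding these adjusted probabilities approximates retraining on the target proportions, without actually retraining.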

<p>Now let’s try to <em>adapt</em> the model to the 0.9 target country:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">train_test_on_source_then_test_on_two_more_countries</span><span class="p">(</span><span class="n">source_prob</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">target_probabilities</span><span class="o">=</span><span class="p">[</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
</code></pre></div></div>

<pre><code class="language-python2">**** country we train+test at class 0 ratio of 0.9

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[89.46  0.32]
 [ 9.69  0.52]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.90      1.00      0.95      7183
           1       0.62      0.05      0.09       817

    accuracy                           0.90      8000
   macro avg       0.76      0.52      0.52      8000
weighted avg       0.87      0.90      0.86      8000


ROC AUC: 0.78
PR AUC (Precision-Recall AUC): 0.31

**** Test in a country with class 0 ratio of 0.9

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[89.64  0.26]
 [ 9.6   0.5 ]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.90      1.00      0.95      7192
           1       0.66      0.05      0.09       808

    accuracy                           0.90      8000
   macro avg       0.78      0.52      0.52      8000
weighted avg       0.88      0.90      0.86      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.31

**** Test in a country with class 0 ratio of 0.5

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[49.74  0.15]
 [47.78  2.34]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.51      1.00      0.67      3991
           1       0.94      0.05      0.09      4009

    accuracy                           0.52      8000
   macro avg       0.72      0.52      0.38      8000
weighted avg       0.73      0.52      0.38      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.76
</code></pre>

<p>We can see that instead of the previous 85% accuracy, we now get 90% accuracy. Our model performs better at deployment time because we adapted it to the new distribution. Matching the training label proportions to the target yields the highest accuracy on the target domain.</p>
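<p>When regenerating training data at the target proportions is not possible, a common alternative (a sketch, not the code used above) is to keep the source data and reweight each training class by $q(y)/p(y)$, for example via scikit-learn’s <code>class_weight</code> parameter:</p>

```python
from sklearn.linear_model import LogisticRegression

def target_matched_model(source_prior, target_prior):
    # Weight class c by q(c)/p(c), so the effective training
    # distribution matches the target label distribution.
    weights = {c: target_prior[c] / source_prior[c]
               for c in range(len(source_prior))}
    return LogisticRegression(class_weight=weights)

# e.g. source is 70%/30% class 0/1, target is 90%/10%:
model = target_matched_model([0.7, 0.3], [0.9, 0.1])
# model.fit(X_train, y_train) then behaves approximately as if
# trained on data drawn at the target proportions.
```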

<p>What will happen if we train our model using balanced classes (0.5)?</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">train_test_on_source_then_test_on_two_more_countries</span><span class="p">(</span><span class="n">source_prob</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">target_probabilities</span><span class="o">=</span><span class="p">[</span><span class="mf">0.9</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span>
</code></pre></div></div>

<p>And the results:</p>

<pre><code class="language-python2">**** country we train+test at class 0 ratio of 0.5

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[34.35 14.96]
 [16.14 34.55]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.68      0.70      0.69      3945
           1       0.70      0.68      0.69      4055

    accuracy                           0.69      8000
   macro avg       0.69      0.69      0.69      8000
weighted avg       0.69      0.69      0.69      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.76

**** Test in a country with class 0 ratio of 0.9

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[62.12 27.86]
 [ 3.16  6.85]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.95      0.69      0.80      7199
           1       0.20      0.68      0.31       801

    accuracy                           0.69      8000
   macro avg       0.57      0.69      0.55      8000
weighted avg       0.88      0.69      0.75      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.30

**** Test in a country with class 0 ratio of 0.5

Confusion Matrix - each row true class (percentage), at threshold of 0.5:
[[35.5  15.09]
 [15.18 34.24]]

Classification Report at threshold 0.5:
              precision    recall  f1-score   support

           0       0.70      0.70      0.70      4047
           1       0.69      0.69      0.69      3953

    accuracy                           0.70      8000
   macro avg       0.70      0.70      0.70      8000
weighted avg       0.70      0.70      0.70      8000


ROC AUC: 0.76
PR AUC (Precision-Recall AUC): 0.75
</code></pre>

<p>We can see that <em>both</em> the accuracy and the ROC AUC remained the same in the two different target distributions (~0.69), but the PR AUC changed.</p>

<p>That means that if you train your model on balanced classes, the model will perform with the same accuracy on any future target distribution. Not necessarily optimal accuracy, but constant.</p>

<h4 id="final-conclusions">Final Conclusions</h4>

<ol>
  <li>
<p>If you want the accuracy (the trace of the normalized confusion matrix) not to change under label shift, train with balanced classes. However, this constant accuracy comes at a price: it will be lower than if you had trained with the class proportions of the target domain.
If you want to achieve the highest accuracy on the target domain, you should train (in the source domain) with the same class proportions as in the target domain. This is a form of domain adaptation.</p>
  </li>
  <li>
<p>It is interesting that the PR AUC in a target domain depends only on the class ratio in that domain; it does <em>not</em> depend on the training ratio in the source domain. So the PR AUC is not affected by class balancing during training; you cannot change it that way.</p>
  </li>
  <li>
    <p>ROC AUC does not change during label-shift, no matter what your training distribution is.</p>
  </li>
  <li>
<p>When you move from domain A to domain B (under the label-shift assumption), the ROC AUC stays the same, while the accuracy and the PR AUC may each improve or worsen. <em>But</em> while you cannot affect the ROC AUC and PR AUC in the target domain, you <em>can</em> affect the accuracy, by retraining the classifier with the right class proportions. Even if accuracy increases when moving from A to B, you can make it higher still by matching the label proportions.</p>
  </li>
</ol>]]></content><author><name>Nadav Benedek</name></author><summary type="html"><![CDATA[TL;DR: If you want the best accuracy on the target domain, you have to match the class frequency in the training set. You cannot affect the ROCAUC nor the PRAUC, but you can affect the accuracy. If you don’t know the target distribution at training time, you can measure the distribution during test time using only the features, calculate a weight-correction for every training example class, and retrain the model so it will match the newly detected distribution (See this paper). Sometimes you can know the future target distribution. For example, if you predict a dice image classifier, you can expect the rolled dice to have a uniform distribution, and so you can balance the classes at training time.]]></summary></entry><entry><title type="html">Memory Footprint of a Neural Net During Backpropagation</title><link href="https://nadavb.com/Memory-Footprint-of-Neural-Net/" rel="alternate" type="text/html" title="Memory Footprint of a Neural Net During Backpropagation" /><published>2024-04-11T14:22:00+03:00</published><updated>2024-04-11T14:22:00+03:00</updated><id>https://nadavb.com/Memory%20Footprint%20of%20Neural%20Net</id><content type="html" xml:base="https://nadavb.com/Memory-Footprint-of-Neural-Net/"><![CDATA[<p>In this article we discuss the memory footprint of a neural network during backpropagation, how backprop works, what affects the memory footprint, code to demonstrate the memory footprint, and more.</p>

<h2 id="backpropagation">Backpropagation</h2>

<p>Let’s have a look at backpropagation in a three-layer network. Let’s denote by $f(\cdot)$ a general function of the inner parameters, by $x_1$ the input to the first layer, by $x_2$ the input to the second layer (and the output of the first layer), and by $x_4$ the output of the network. So we have:
 $x_2=f(x_1=\text{input}, w_1) \quad|\quad  x_3=f(x_2,w_2)  \quad|\quad   x_4=f(x_3,w_3)  \quad|\quad L=f(x_4) $</p>

<p>Now, if we want to find the gradient of $w_1$ we have: $ \frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial x_4} \frac{\partial x_4}{\partial x_3}   \frac{\partial x_3}{\partial x_2}   \frac{\partial x_2}{\partial w_1}$</p>

<p>Let’s mark the parameters of $f(\cdot)$ in bold when they are essential in the general case, and in light font when they are only sometimes needed, depending on the actual function. For example, if $x_4=x_3 w_3$ then $ \frac{\partial x_4}{\partial x_3}=w_3$, so the derivative is <strong>only</strong> a function of $w_3$ and <strong>does not depend on the layer input activation $x_3$</strong>. However, if $x_4=\sigma (x_3 w_3)$, where $\sigma$ is a sigmoid or ReLU, then the derivative depends on both variables. That’s what makes nonlinearity non-linear, after all. So, let’s look at all the terms in the chain rule, and for each of them, analyze whether it always depends on the input to the layer ($x$) or only sometimes:</p>

\[\newcommand{\mb}[1]{\mathbf{#1}}
\newcommand{\mi}[1]{\textit{#1}}

\frac{\partial L}{\partial w_1} =  \underbrace{  \frac{\partial L}{\partial x_4}}_{f(\mathbf{x_4})}     \underbrace{ \frac{\partial x_4}{\partial x_3} }_{f(x_3, \mb{w_3})}       \underbrace{\frac{\partial x_3}{\partial x_2} }_{f(x_2, \mb{w_2})}     \underbrace{ \frac{\partial x_2}{\partial w_1}}_{f(\mb{x_1}, w_1)}\]

<p><u>Observation 1</u>:  We can see that the first term (which represents the <strong>last</strong> layer), the impact of the network output on the loss, always depends on the output of the network, which is obvious. In the last term, which is the layer we want to optimize, the <strong>input to the layer is mandatory</strong>, unless the weight we’re interested in is, for example, the bias, in which case the derivative does not depend on the input.</p>

<p><u>Observation 2</u>:  What information do we need to store, before the backprop starts, in order to update $w_1$? We can see that in the <strong>general</strong> case, we need all layer inputs/outputs, meaning all $x_1 .. x_4$. That means that after the <strong>forward()</strong> pass, we must store all activations of the network: the input, the hidden representations, and the output (which is the input to the next layer). <strong>However</strong> in some cases, for example if one of the <strong>inner</strong> layers is a <strong>linear layer</strong> with no nonlinearity, we do not need to store the input for the layer during the forward pass. This will be demonstrated using the code below.</p>

<p>Combining the two observations, we can conclude that in the special case where (1) we have a layer which is <strong>frozen</strong> (meaning that we do not want to optimize its weights) and (2) the frozen layer is a <strong>linear layer</strong>, we can choose <strong>not</strong> to store its input activation during the forward pass, since it is not needed for optimizing the layer itself (it’s frozen) and not needed for updating upstream weights in the DAG/computation graph.</p>

<p>Furthermore, for efficiency, the <strong>backprop</strong> starts from the end. Here are the steps:</p>

<ol>
  <li>
<p>We first compute and hold as state $s_4 =  \underbrace{  \frac{\partial L}{\partial x_4}}_{f(\mathbf{x_4})} $. Reminder: $x_4$ is the output of the network, so what we are calculating is the derivative of the loss function with respect to the model output. We need $x_4$ to calculate this gradient, but once it is computed, we can release the activation $x_4$ from memory as we will not use it anymore.</p>
  </li>
  <li>
<p>If the third layer is unfrozen, update the (last in the chain) weight \(\nabla w_3 = s_4  \underbrace{  \frac{\partial x_4}{\partial w_3}}_{f(\mb{x_3},w_3)}\), then compute $s_3 = s_4  \underbrace{ \frac{\partial x_4}{\partial x_3} }_{f(x_3, \mb{w_3})}  $ , and now we can release the activation $x_3$ from memory. We can see that if a layer is both frozen and linear, we do not use the activation $x_3$ at all, and in this case we do not need to store it in the first place.</p>
  </li>
  <li>
<p>If the second layer is unfrozen, update the weight \(\nabla w_2 = s_3 \underbrace{  \frac{\partial x_3}{\partial w_2}}_{f(\mb{x_2},w_2)}\). Note that this calculation needs both the temporary gradient flow that arrived backward from the next layer and the input activation to this layer. You can think of it as follows: we need information from both sides, the input information and the feedback from the output channel.
Compute \(s_2 = s_3  \underbrace{\frac{\partial x_3}{\partial x_2} }_{f(x_2, \mb{w_2})}\), and now we can release the activation $x_2$ from memory.</p>
  </li>
  <li>
    <p>If the first layer is unfrozen, update the weight \(\nabla w_1 = s_2  \underbrace{ \frac{\partial x_2}{\partial w_1}}_{f(\mb{x_1}, w_1)}\).  We can release the input to the network $x_1$ and we’re done.</p>
  </li>
</ol>
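<p>The four steps above can be sketched in a few lines of plain Python. This is an illustrative toy with scalar sigmoid layers (the numbers and names are made up): one running state $s$ flows backward, each stored activation is used once and then popped (released), and the quantities of the form $f(x_i, w_i)$ are recomputed from the layer input and weight alone:</p>

```python
import math

def sigma(z):  # the nonlinearity; sigma'(z) = sigma(z) * (1 - sigma(z))
    return 1.0 / (1.0 + math.exp(-z))

# Forward pass with layers x_{i+1} = sigma(x_i * w_i): store every activation
x1, w1, w2, w3 = 0.5, 0.3, -0.7, 1.2
acts = {"x1": x1}
acts["x2"] = sigma(acts["x1"] * w1)
acts["x3"] = sigma(acts["x2"] * w2)
acts["x4"] = sigma(acts["x3"] * w3)

# Backward pass, mirroring steps 1-4: keep one running state s, use each
# stored activation once, then pop (release) it from the activation store.
x4 = acts.pop("x4")
s = 2.0 * (x4 - 1.0)                    # step 1: s4 = dL/dx4 for L = (x4-1)^2

x3 = acts.pop("x3")
a = sigma(x3 * w3)                      # recomputed from (x3, w3) only
grad_w3 = s * a * (1.0 - a) * x3        # step 2: s4 * dx4/dw3
s = s * a * (1.0 - a) * w3              # s3 = s4 * dx4/dx3; x3 now released

x2 = acts.pop("x2")
a = sigma(x2 * w2)
grad_w2 = s * a * (1.0 - a) * x2        # step 3: s3 * dx3/dw2
s = s * a * (1.0 - a) * w2              # s2

x1_in = acts.pop("x1")
a = sigma(x1_in * w1)
grad_w1 = s * a * (1.0 - a) * x1_in     # step 4: s2 * dx2/dw1
assert not acts                         # every activation has been released

# Cross-check grad_w1 with a central finite difference
def loss(w1_):
    x2_ = sigma(x1 * w1_); x3_ = sigma(x2_ * w2); x4_ = sigma(x3_ * w3)
    return (x4_ - 1.0) ** 2

eps = 1e-6
fd = (loss(w1 + eps) - loss(w1 - eps)) / (2 * eps)
assert abs(grad_w1 - fd) < 1e-8
```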

<p><u>Observation 3</u>: We do not need to hold the input activation to a layer which is frozen and for which all the upstream weights (the ancestor weights in the DAG) are frozen too, since nothing will use the computation of $s$ there.</p>

<p><strong>To conclude</strong>:</p>

<p>Activation memory allocation: when a layer is frozen AND (the derivative of its output w.r.t. its input does not depend on the input, as in the linear case, OR all upstream dependent weights are frozen too), we can save memory and not store the input activation.</p>

<p>Activation memory deallocation: during backprop, we can release the activations we’ve already used, to free memory.</p>

<h2 id="network-memory-footprint">Network Memory Footprint</h2>

<p>In a PyTorch <strong>training loop</strong>, we have five basic steps:</p>

<ol>
  <li>
    <p><strong>Load</strong> the network to memory. If the network has 100M parameters and we use 32bit float per parameter, it will take 400MB.</p>
  </li>
  <li>
<p>Compute the <strong>$\mi{forward()}$</strong> pass, and store some activations, depending on the conclusions above. Activations are stored for each sample in a batch, so the memory footprint depends on the <strong>batch size</strong>. If we train with mixed precision, the forward activations are kept in 16 bits instead of 32, so the footprint is halved.</p>
  </li>
  <li>
<p>Compute the $\mi{backward()}$ pass: allocate gradient storage for the unfrozen parameters, use the activations we’ve stored to compute the gradients of unfrozen layers, and <strong>release</strong> unneeded activations as we go backward. In this process we calculate two types of gradients: gradients w.r.t. the weights, which are stored in param.grad, and intermediate gradients w.r.t. the input activations, which are needed internally to continue the backprop and are freed when no longer needed.</p>
  </li>
  <li>
<p>Run the optimizer for the unfrozen layers, $\mi{optimizer.step()}$: it uses the gradients we calculated, and stores and updates internal moments/optimizer state only for unfrozen layers. Batch size does not affect the memory allocation of the optimizer, since all gradients are summed in place, and when a GPU is used, cores work in parallel to update the .grad of the tensors. If we use the Adam optimizer, two moments are kept for each parameter.</p>
  </li>
  <li>
<p>Release the gradients we’ve accumulated using $\mi{zero\_grad()}$.</p>
  </li>
</ol>
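<p>As a rough sketch, the five steps above can be turned into a back-of-the-envelope byte count. The helper below is illustrative, not a real profiler; the activations-per-sample figure is an arbitrary made-up number:</p>

```python
# Rough per-step footprint for the training loop above. Assumes float32,
# every parameter unfrozen, and Adam (two moments per parameter).
def training_memory_bytes(n_params, acts_per_sample, batch_size,
                          bytes_per_float=4, adam_moments=2):
    return {
        "model": n_params * bytes_per_float,                            # step 1
        "activations": batch_size * acts_per_sample * bytes_per_float,  # step 2
        "gradients": n_params * bytes_per_float,                        # step 3
        "optimizer_state": adam_moments * n_params * bytes_per_float,   # step 4
    }

# The 100M-parameter example from step 1 (acts_per_sample is made up):
est = training_memory_bytes(100_000_000, acts_per_sample=1_000_000, batch_size=8)
assert est["model"] == 400_000_000            # the 400MB from step 1
assert est["optimizer_state"] == 800_000_000  # two Adam moments
```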

<h2 id="demonstration-code">Demonstration Code</h2>

<p>Run this code:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">torch</span>
<span class="kn">import</span> <span class="nn">torch.nn</span> <span class="k">as</span> <span class="n">nn</span>
<span class="k">def</span> <span class="nf">test_memory</span><span class="p">(</span><span class="n">in_size</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">out_size</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">freeze_start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">freeze_end</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
                <span class="n">hidden_size</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">optimizer_type</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                <span class="n">device</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">add_relu</span><span class="o">=</span><span class="bp">True</span><span class="p">):</span>

  <span class="n">sample_input</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">randn</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">in_size</span><span class="p">)</span>

  <span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">)]</span>
  <span class="k">for</span> <span class="n">layer_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span>
    <span class="n">layers_to_append</span> <span class="o">=</span> <span class="p">[</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">,</span> <span class="n">bias</span><span class="o">=</span><span class="bp">False</span><span class="p">)]</span>
    <span class="k">if</span> <span class="n">add_relu</span><span class="p">:</span>
      <span class="n">layers_to_append</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">ReLU</span><span class="p">())</span>

    <span class="c1"># Selectively freeze some layers
</span>    <span class="k">if</span> <span class="n">freeze_start</span> <span class="o">&lt;=</span> <span class="n">layer_index</span> <span class="o">&lt;</span> <span class="n">freeze_end</span><span class="p">:</span>
      <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">layers_to_append</span><span class="p">:</span>
        <span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">layer</span><span class="p">.</span><span class="n">parameters</span><span class="p">():</span>
          <span class="n">param</span><span class="p">.</span><span class="n">requires_grad</span> <span class="o">=</span> <span class="bp">False</span>

    <span class="n">layers</span><span class="p">.</span><span class="n">extend</span><span class="p">(</span><span class="n">layers_to_append</span><span class="p">)</span>

  <span class="n">layers</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">hidden_size</span><span class="p">,</span> <span class="n">out_size</span><span class="p">))</span>
  <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"number of layers: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">layers</span><span class="p">)</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
  <span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="p">.</span><span class="n">Sequential</span><span class="p">(</span><span class="o">*</span><span class="n">layers</span><span class="p">)</span>

  <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_type</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="p">.</span><span class="mi">001</span><span class="p">)</span>
  <span class="n">start</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  <span class="k">print</span><span class="p">(</span><span class="s">"Starting at 0 memory usage as baseline."</span><span class="p">)</span>
  <span class="n">model</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>
  <span class="n">after_model</span> <span class="o">=</span>  <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="o">-</span> <span class="n">start</span>
  <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"1: After model to device: </span><span class="si">{</span><span class="n">after_model</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
  <span class="k">print</span><span class="p">(</span><span class="s">""</span><span class="p">)</span>
  <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span>
    <span class="k">print</span><span class="p">(</span><span class="s">"Iteration"</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>

    <span class="n">a</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>  <span class="o">-</span> <span class="n">start</span>

    <span class="c1"># Running the forward pass. Here all activations will be saved, 
</span>    <span class="c1"># per every sample in batch
</span>    <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">sample_input</span><span class="p">.</span><span class="n">to</span><span class="p">(</span><span class="n">device</span><span class="p">)).</span><span class="nb">sum</span><span class="p">()</span>
    <span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="o">-</span> <span class="n">start</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"2: Memory consumed after forward pass (activations stored, depends on batch size): </span><span class="si">{</span><span class="n">b</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s"> change: "</span><span class="p">,</span> <span class="sa">f</span><span class="s">'</span><span class="si">{</span><span class="n">b</span> <span class="o">-</span> <span class="n">a</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">'</span> <span class="p">)</span>  <span class="c1"># batch * num layers * hidden_size * 4 bytes per float
</span>
    <span class="c1"># Backward step: Here we allocate (unless already allocated) 
</span>    <span class="c1"># and store the gradient of each non-frozen parameter,
</span>    <span class="c1"># and we release/discard the activations which are descendants in the DAG as we go.
</span>    <span class="c1"># So at the end the change in memory = +non-frozen parameters (if was unallocated) - non-degenerate activations
</span>    <span class="c1"># gradients are accumulated in place in the .grad attribute of the tensors 
</span>    <span class="c1"># for which gradients are being computed. Each GPU core works on a different
</span>    <span class="c1"># part of the .grad tensor, so they can all work in parallel
</span>    <span class="n">out</span><span class="p">.</span><span class="n">backward</span><span class="p">()</span>
    <span class="n">c</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="o">-</span> <span class="n">start</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"3: After backward pass (activations released, grad stored) </span><span class="si">{</span><span class="n">c</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s"> change: </span><span class="si">{</span><span class="n">c</span><span class="o">-</span><span class="n">b</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>

    <span class="c1"># Running the optimizer, at the first time, will store 2 moments for each non-frozen parameter (if using Adam), which will be kept throughout the training
</span>    <span class="c1"># So change in memory, in the first time = 2 * non-frozen parameters
</span>    <span class="c1"># optimizer changes the model parameters in place
</span>    <span class="n">optimizer</span><span class="p">.</span><span class="n">step</span><span class="p">()</span>
    <span class="n">d</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>  <span class="o">-</span> <span class="n">start</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"4: After optimizer step (moments stored at first time): </span><span class="si">{</span><span class="n">d</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s"> change: </span><span class="si">{</span><span class="n">d</span><span class="o">-</span><span class="n">c</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s"> "</span> <span class="p">)</span>

    <span class="c1"># zero_grad = Reset and release gradients tensors created in .backward()
</span>    <span class="n">model</span><span class="p">.</span><span class="n">zero_grad</span><span class="p">()</span>
    <span class="n">e</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">memory_allocated</span><span class="p">(</span><span class="n">device</span><span class="p">)</span>  <span class="o">-</span> <span class="n">start</span>
    <span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"5: After zero_grad step (grads released): </span><span class="si">{</span><span class="n">e</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s"> change: </span><span class="si">{</span><span class="n">e</span><span class="o">-</span><span class="n">d</span><span class="si">:</span><span class="p">,</span><span class="si">}</span><span class="s"> "</span> <span class="p">)</span>
    <span class="k">print</span><span class="p">(</span><span class="s">""</span><span class="p">)</span>

<span class="n">test_memory</span><span class="p">(</span><span class="n">optimizer_type</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="n">Adam</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">freeze_start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">freeze_end</span><span class="o">=</span><span class="mi">0</span>
            <span class="p">,</span> <span class="n">add_relu</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</code></pre></div></div>

<p>Let’s have a look at the second iteration, for example:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="n">Iteration</span> <span class="mi">2</span>
<span class="mi">2</span><span class="p">:</span> <span class="n">Memory</span> <span class="n">consumed</span> <span class="n">after</span> <span class="n">forward</span> <span class="k">pass</span> <span class="p">(</span><span class="n">activations</span> <span class="n">stored</span><span class="p">,</span> <span class="n">depends</span> <span class="ow">on</span> <span class="n">batch</span> <span class="n">size</span><span class="p">):</span> <span class="mi">46</span><span class="p">,</span><span class="mi">616</span><span class="p">,</span><span class="mi">576</span> <span class="n">change</span><span class="p">:</span>  <span class="mi">5</span><span class="p">,</span><span class="mi">171</span><span class="p">,</span><span class="mi">200</span>
<span class="mi">3</span><span class="p">:</span> <span class="n">After</span> <span class="n">backward</span> <span class="k">pass</span> <span class="p">(</span><span class="n">activations</span> <span class="n">released</span><span class="p">,</span> <span class="n">grad</span> <span class="n">stored</span><span class="p">)</span> <span class="mi">49</span><span class="p">,</span><span class="mi">580</span><span class="p">,</span><span class="mi">544</span> <span class="n">change</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span><span class="mi">963</span><span class="p">,</span><span class="mi">968</span>
<span class="mi">4</span><span class="p">:</span> <span class="n">After</span> <span class="n">optimizer</span> <span class="n">step</span> <span class="p">(</span><span class="n">moments</span> <span class="n">stored</span> <span class="n">at</span> <span class="n">first</span> <span class="n">time</span><span class="p">):</span> <span class="mi">49</span><span class="p">,</span><span class="mi">580</span><span class="p">,</span><span class="mi">544</span> <span class="n">change</span><span class="p">:</span> <span class="mi">0</span> 
<span class="mi">5</span><span class="p">:</span> <span class="n">After</span> <span class="n">zero_grad</span> <span class="n">step</span> <span class="p">(</span><span class="n">grads</span> <span class="n">released</span><span class="p">):</span> <span class="mi">41</span><span class="p">,</span><span class="mi">445</span><span class="p">,</span><span class="mi">376</span> <span class="n">change</span><span class="p">:</span> <span class="o">-</span><span class="mi">8</span><span class="p">,</span><span class="mi">135</span><span class="p">,</span><span class="mi">168</span> 

</code></pre></div></div>

<p>What’s going on? The forward() pass allocated 5M of activations memory. You can play with batch_size and see how it affects the activation memory size. You can set freeze_end=200 and watch the activation memory drop; however, if you then set add_relu=True, the memory footprint goes up again, since the layers are no longer linear, as we showed above.</p>

<p>The backward() allocated 8M for the gradients (the size of the unfrozen model), but released 5M of activations, so at the end we see a net increase of 3M.</p>

<p>The optimizer step allocated nothing in this iteration, since it already allocated 16M at the first iteration, which is exactly two moments per each parameter in the model.</p>

<p>And the zero_grad released the 8M of gradients.</p>
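<p>These numbers can be reproduced with a little arithmetic, assuming that the CUDA caching allocator rounds every tensor up to a 512-byte block:</p>

```python
def round512(nbytes):
    # The CUDA caching allocator rounds each tensor up to a 512-byte block
    return -(-nbytes // 512) * 512  # ceiling division, then scale back up

batch, hidden = 64, 100

# Activations: the float32 input to each of the 202 linear layers
# (first layer + 200 hidden layers + last layer), one (64, 100) tensor each
activations = 202 * round512(batch * hidden * 4)
assert activations == 5_171_200          # the forward-pass change above

# Gradients: one .grad tensor per parameter tensor
first = round512(100 * 100 * 4) + round512(100 * 4)    # weight + bias
hidden_grads = 200 * round512(hidden * hidden * 4)     # bias=False layers
last = round512(hidden * 10 * 4) + round512(10 * 4)    # weight + bias
gradients = first + hidden_grads + last
assert gradients == 8_135_168            # the amount zero_grad released
```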

<h2 id="peak-memory-consumption">Peak Memory Consumption</h2>

<p>So what is the peak memory consumption for a network? We have two places where we could potentially reach peak memory consumption.</p>

<ol>
  <li>After forward(): model + 2 x model (for Adam optimizer) + activations (batch dependent)</li>
  <li>After backward(): model + 2 x model (for Adam optimizer) + gradients (non-frozen model params)</li>
</ol>

<p>So the answer depends on which component dominates in the specific network: the activations or the gradients. For example, in CNNs we may have a large activation space even with few parameters. In Transformers, the activations depend on the sequence length. Activations are batch dependent, while the gradient memory footprint depends only on the size of the weights.</p>

<p>So we can phrase the peak memory consumption as:</p>

<p>model + 2 x model (for the Adam optimizer) + MAX( gradients (non-frozen model params; can be doubled if more than one GPU is used in training, for accumulation), activations [batch size * activations_per_sample * activation_precision] )</p>
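<p>That expression, as a tiny helper (a sketch; all sizes in bytes, float32 model, Adam with two moments, and the multi-GPU doubling exposed as a parameter):</p>

```python
# Peak memory per the formula above: weights + Adam state + the larger of
# (gradient storage) and (activation storage).
def peak_memory(model_bytes, grad_bytes, activation_bytes, grad_copies=1):
    optimizer_state = 2 * model_bytes               # Adam's two moments
    return model_bytes + optimizer_state + max(grad_copies * grad_bytes,
                                               activation_bytes)

# Activation-dominated regime (e.g. a CNN with big feature maps):
assert peak_memory(400, grad_bytes=400, activation_bytes=6000) == 7200
# Gradient-dominated regime, doubled for multi-GPU accumulation:
assert peak_memory(400, grad_bytes=400, activation_bytes=100,
                   grad_copies=2) == 2000
```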

<h2 id="methods-to-reduce-memory-footprint">Methods to reduce memory footprint</h2>

<h3 id="gradient-accumulation">Gradient Accumulation</h3>

<p>We can run the optimizer.step() and optimizer.zero_grad() steps only once in a while, essentially splitting a batch into sub-batches. As seen above, this reduces the activation memory allocation, but does not reduce the gradient memory allocation.</p>
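<p>A plain-Python sanity check of why this works (toy data, hypothetical names): for a loss that sums over samples, sub-batch gradients add up to the full-batch gradient, so stepping once after several sub-batches is equivalent, while only one sub-batch’s activations need to be alive at a time:</p>

```python
# Gradient of a summed MSE loss for a 1-parameter linear model y_hat = w * x
def grad_mse(w, xs, ys):            # d/dw sum_i (w*x_i - y_i)^2
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys))

w = 0.7
xs = [0.1, 0.4, -0.3, 0.9, 1.1, -0.2]
ys = [0.2, 0.5, -0.1, 1.0, 0.9, 0.0]

full = grad_mse(w, xs, ys)          # one big batch: all activations at once

accum, sub = 0.0, 2                 # accumulate over sub-batches of size 2
for i in range(0, len(xs), sub):
    accum += grad_mse(w, xs[i:i+sub], ys[i:i+sub])

assert abs(full - accum) < 1e-12    # same gradient, smaller live working set
```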

<h3 id="gradient-checkpointing">Gradient Checkpointing</h3>

<p>This really should have been called Activation Checkpointing. Instead of storing all the activations computed in the forward() pass, we store only a <strong>subset</strong> of them, reducing the memory footprint, and re-compute the missing activations on the fly, only when needed during the backward() computation.</p>
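<p>A toy sketch of the idea for a chain of sigmoid layers (PyTorch’s real implementation lives in torch.utils.checkpoint; everything below is illustrative): store only every k-th activation during forward, then recompute any missing activation from the nearest earlier checkpoint during backward:</p>

```python
import math

def sigma(z):
    return 1.0 / (1.0 + math.exp(-z))

ws = [0.9, -0.4, 0.7, 1.1, -0.6, 0.5]    # illustrative layer weights

def forward(x, store_every):
    saved = {0: x}                       # x_0, the network input
    for i, w in enumerate(ws):
        x = sigma(x * w)
        if (i + 1) % store_every == 0:
            saved[i + 1] = x             # checkpoint x_{i+1}
    return x, saved

x0 = 0.5
out, saved = forward(x0, store_every=3)  # checkpoints at x_0, x_3, x_6
assert len(saved) == 3                   # instead of all 7 activations

def activation(i):
    """Recompute x_i on the fly from the nearest earlier checkpoint."""
    j = max(c for c in saved if c <= i)
    x = saved[j]
    for w in ws[j:i]:                    # re-run the forward segment
        x = sigma(x * w)
    return x

# Every recomputed activation matches the fully-stored reference
_, ref = forward(x0, store_every=1)
assert all(abs(activation(i) - ref[i]) < 1e-12 for i in range(len(ws) + 1))
```

<p>The usual trade-off applies: memory drops roughly by the checkpoint spacing, at the price of one extra partial forward pass during backprop.</p>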

\[\square\]]]></content><author><name>Nadav Benedek</name></author><category term="training,memory,neural" /><category term="net," /><category term="dnn," /><category term="backward," /><category term="batch" /><summary type="html"><![CDATA[In this article we discuss the memory footprint of a neural network during backpropagation, how backprop works, what affects the memory footprint, code to demonstrate the memory footprint, and more.]]></summary></entry><entry><title type="html">PRILoRA: Pruned and Rank-Increasing Low-Rank Adaptation</title><link href="https://nadavb.com/PRILoRA/" rel="alternate" type="text/html" title="PRILoRA: Pruned and Rank-Increasing Low-Rank Adaptation" /><published>2024-01-21T05:20:00+02:00</published><updated>2024-01-21T05:20:00+02:00</updated><id>https://nadavb.com/PRILoRA</id><content type="html" xml:base="https://nadavb.com/PRILoRA/"><![CDATA[<p>With the proliferation of large pre-trained language models (PLMs), fine-tuning all model parameters becomes increasingly inefficient, particularly when dealing with numerous downstream tasks that entail substantial training and storage costs. Several approaches aimed at achieving parameter-efficient fine-tuning (PEFT) have been proposed. Among them, Low-Rank Adaptation (LoRA) stands out as an archetypal method, incorporating trainable rank decomposition matrices into each target module. Nevertheless, LoRA does not consider the varying importance of each layer. To address these challenges, we introduce PRILoRA, which linearly allocates a different rank for each layer, in an increasing manner, and performs pruning throughout the training process, considering both the temporary magnitude of weights and the accumulated statistics of the input to any given layer. We validate the effectiveness of PRILoRA through extensive experiments on eight GLUE benchmarks, setting a new state of the art.</p>

<p>See the full <a href="https://arxiv.org/abs/2401.11316">paper</a>.</p>]]></content><author><name>Nadav Benedek</name></author><category term="LLM," /><category term="LoRA," /><category term="finetuning" /><summary type="html"><![CDATA[With the proliferation of large pre-trained language models (PLMs), fine-tuning all model parameters becomes increasingly inefficient, particularly when dealing with numerous downstream tasks that entail substantial training and storage costs. Several approaches aimed at achieving parameter-efficient fine-tuning (PEFT) have been proposed. Among them, Low-Rank Adaptation (LoRA) stands out as an archetypal method, incorporating trainable rank decomposition matrices into each target module. Nevertheless, LoRA does not consider the varying importance of each layer. To address these challenges, we introduce PRILoRA, which linearly allocates a different rank for each layer, in an increasing manner, and performs pruning throughout the training process, considering both the temporary magnitude of weights and the accumulated statistics of the input to any given layer. We validate the effectiveness of PRILoRA through extensive experiments on eight GLUE benchmarks, setting a new state of the art.]]></summary></entry><entry><title type="html">Feature-preprocessing/engineering leakage during data-preparation and Train-Test Split Strategy Protocol</title><link href="https://nadavb.com/Feature-Preprocessing-Leakage-During-Data-Preparation/" rel="alternate" type="text/html" title="Feature-preprocessing/engineering leakage during data-preparation and Train-Test Split Strategy Protocol" /><published>2023-08-05T11:27:00+03:00</published><updated>2023-08-05T11:27:00+03:00</updated><id>https://nadavb.com/Feature-Preprocessing%20Leakage%20During%20Data-Preparation</id><content type="html" xml:base="https://nadavb.com/Feature-Preprocessing-Leakage-During-Data-Preparation/"><![CDATA[<p>Are we allowed to transform the input data in any way we want? 
Can we train sub-models to preprocess features? Can we use a pipeline of models? Can we use the output of one model as an input of another model?</p>

<p>Assume we have a supervised learning problem, and we would like to preprocess a feature using a separate supervised model.</p>

<h4 id="minimal-problem-example">Minimal problem example:</h4>

<p>We would like to predict a 0-1 label (boolean), using 2 features: 1 numerical and 1 textual feature, using a decision tree (call it model A). Since decision trees use numbers, we would like to take the textual feature and <em>transform</em> it, using a separate supervised model that takes a textual input, predicts a scalar label (call it model B), and use this model to convert the textual feature into a numerical feature, so that we will have 2 numerical features, and then use a decision tree to predict the 0-1 label (boolean).</p>

<p>The question is if this process is legit. And if it is legit, are there any restrictions on the process to make it legit?</p>

<p>To make it more specific, can we first train model B, any way we want, and then transform the feature and train model A? Can we do a train/test split any way we want (randomly) when training model B, and then do a train/test split (randomly) when training model A? Or must the split be the same when training model B and model A? If the latter is required, it can be complicated in real-life scenarios, where you need to enforce the same train/test split procedure across all ML teams in an organization involved in the project.</p>

<p>Let’s make the problem even more simple: Assume that the textual feature is just a random string, meaning that it carries zero information in it, and that the numerical feature is also random and has no correlation with the label. Assume we have 1000 examples.</p>

<p>So, we train model B in a very overfitted manner, meaning that the training accuracy (800 examples) is 100%, and test accuracy (200 examples) is 50% (random guessing). This can happen when the model memorizes all the random texts that correspond with 0-labels and all the texts that correspond with 1-labels. That means that transforming the textual feature using model B will convert the training set features into the training-labels themselves.</p>

<p>Now, let’s say that in model A training, 100% of the examples in the test-set (200 of 200) are actually train-set examples of model B (because we did a new random train/test split). As for the training-set, 600 out of the 800 are examples that were part of the training-set of model B. They are completely overfitted, so they (the features) contain the label itself. The other 200 are random and have no correlation to the label. So training A will yield a model that simply uses the overfitted feature, generated by model B, to predict the label. Therefore, the test accuracy of model A will be 100%, although the features have zero information in them, because the transformed feature predicts exactly the label.</p>
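<p>The failure mode above can be reproduced in a few lines. Everything here is illustrative (sizes, seeds, names): model B is a pure memorizer, the labels are pure chance, and model A simply predicts the engineered feature after a fresh random split:</p>

```python
import random

random.seed(0)
n = 1000
labels = [random.randint(0, 1) for _ in range(n)]   # pure chance labels

# Model B memorizes its 800 training examples; elsewhere it guesses.
b_train = set(random.sample(range(n), 800))
def model_b(i):
    return labels[i] if i in b_train else random.randint(0, 1)

# "Feature engineering": replace the textual feature with model B's output
feature = [model_b(i) for i in range(n)]

# Model A draws a FRESH random split and just predicts the engineered feature
a_test = random.sample(range(n), 200)
accuracy = sum(feature[i] == labels[i] for i in a_test) / len(a_test)
assert accuracy > 0.7    # far above the 50% chance level, with zero signal
```

<p>Most of model A’s test examples fall inside model B’s training set, where the feature literally is the label, so the measured accuracy lands far above chance despite the data carrying no information.</p>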

<style>
figure {
  display: flex;        /* Use flexbox to center content */
  flex-direction: column; /* Stack image and caption vertically */
  align-items: center;  /* Center content horizontally */
  margin: 20px auto;    /* Add margin for spacing */
}

figcaption {
  text-align: center;   /* Center the caption text */
  font-size: 0.9em;     /* Optional: Adjust font size */
  color: #fff;          /* Optional: Change caption color */
  margin-top: 10px;     /* Add space between the image and caption */
  max-width: 100%;      /* Ensure long captions fit within the container */
}
</style>

<figure>
  <img src="/assets/leakage/feature_leakage.png" alt="Feature Leakage Diagram" />
  <figcaption>Figure 1: At first, we have a data set with completely random features (R). Model B is used for preprocessing features, and overfits on the training set heavily. Model B training accuracy is 100% and the test accuracy is random (50%). Then, Model A uses the engineered features, with a different train/test split, and achieves 100% test accuracy.</figcaption>
</figure>

<p>This is an example of a case where features are completely random, but we reached 100% test accuracy. We could have tweaked the example to reach any test performance we wanted.</p>

<p>I call this “Preprocessing Leakage”. If we had kept the train/test split identical across model B and model A, the problem would have been avoided.</p>

<p>Mitigation: One way to prevent this preprocessing leakage is to <em>avoid a random train/test split</em> and instead split deterministically, using a stable hash function over the examples. For example, split by the hash of the user id, account id, etc., so that all sub-models share the same train/test split.
This also means having full control over the train/test split, and avoiding third-party libraries that each split the data in their own way.</p>
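<p>A minimal sketch of such a deterministic split (the function name and ratio are illustrative). Note that Python’s built-in <code>hash()</code> is salted per process, so a cryptographic digest such as md5 is used to keep the assignment stable across runs, machines, and teams:</p>

```python
import hashlib

def stable_split(example_id: str, test_ratio: float = 0.2) -> str:
    """Deterministically assign an example to 'train' or 'test' by hashing a stable id.

    Any team that hashes the same id (user id, account id, ...) gets the same
    answer, so all sub-models share one split without coordinating a random seed.
    """
    # md5 gives a platform-independent, process-independent hash
    digest = hashlib.md5(example_id.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % 10_000
    return "test" if bucket < test_ratio * 10_000 else "train"
```

Over many examples, the fraction assigned to "test" converges to <code>test_ratio</code>, and the assignment of any given id never changes.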

<p>Another way, when possible, is to do the split before training any model. This is not always possible in an organization with a feature store, where many teams insert new features into the databases, and sometimes insert trained features. Having full control over the way every team trains its feature-models can be very difficult. In many cases there are features whose meaning you don’t even know, let alone how they were produced, by some trainable model you are not responsible for.</p>

<p>The train/test split is important, and we need to keep it under control.</p>

<h1 id="train-test-split-strategy-protocol">Train-Test Split Strategy Protocol</h1>

<p><strong>Seed-Stable</strong>: Sometimes, when the dataset is not too large, you don’t want a three-way train/dev/test split, in order to avoid losing data from the training set. However, when you split into only train+test, you risk overfitting the test_set.
So, after you have fixed the test_set and measured the metrics, you can take the all_set, split it into a different train/test split using a different random seed, train the model again, and evaluate on the new test_set. If the two evaluations, each using a different random seed, agree closely, we call the model <strong>seed-stable</strong>.
You can do more re-trainings, using different seed-splits, to increase your confidence in the seed-stability; at that point it is quite similar to k-fold cross-validation or Monte Carlo cross-validation. In any case, when several ML teams use a shared dataset to develop models for sub-tasks, or a pipeline of models in which one model’s output feeds another model, <strong>all published models</strong> must not use any examples from the shared global test_set in the <strong>final training</strong> run used to publish the model.</p>
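<p>The seed-stability check can be wrapped in a small helper (a sketch; <code>train_and_eval</code> and the tolerance are placeholders for your own training routine and metric):</p>

```python
def is_seed_stable(train_and_eval, seeds=(0, 1, 2), tolerance=0.02):
    """Re-split and re-train once per seed, then check the metric spread.

    train_and_eval(seed) is expected to: split all_set using the given seed,
    train the model from scratch, and return the test metric (e.g. accuracy).
    """
    scores = [train_and_eval(seed) for seed in seeds]
    spread = max(scores) - min(scores)
    return spread <= tolerance, scores
```

If the spread exceeds the tolerance, the single fixed test_set you were using is probably not a trustworthy estimate of generalization.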

<p><strong>Access Safety</strong>: On the one hand, we don’t want training loops to accidentally have access to a folder or a database containing both the train and test examples, to avoid a model mistakenly training on a test example. On the other hand, examples often do reside in one folder/database, and the all_set keeps growing as we collect more examples.
A solution is to hold a file called split.json that records the split; all the training procedures (across the various ML teams) have access to this file, which points to the examples/files in the databases/directories.</p>

<p><strong>Who generates the split.json file?</strong> It is generated by a program that receives a folder or a database of unsplit examples and creates the split_file. If the file does not exist, it randomly splits the examples in the folder and saves the file. If the file does exist, it runs: (a) a <strong>validation procedure</strong>, to make sure the file is valid: no train/test overlap, no missing pointers, no extra pointers, etc., and (b) an <strong>update procedure</strong>.
Every sub-algorithm, sub-model, or derived dataset must use the global split file to construct its train and test datasets, and must run the validation procedure.</p>
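<p>One possible shape for split.json, matching the <code>path</code>/<code>hash</code> entries that the DatasetSplitter validation code later in this post expects (the paths and hashes here are made up):</p>

```json
{
  "train": [
    {"path": "/data/images/img_0001.png", "hash": "a1b2c3"},
    {"path": "/data/images/img_0002.png", "hash": "d4e5f6"}
  ],
  "test": [
    {"path": "/data/images/img_0003.png", "hash": "0a1b2c"}
  ]
}
```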

<p>Each ml-team can check that the model is seed-stable on a different seed-split of the all_set, to make sure the model does not overfit the test_set, or alternatively use a different seed-split as develop_set, but their <strong>published</strong> model <strong>must</strong> use the global split as defined by the split_file.</p>

<p>Observation: in some companies, where different ML teams work on different models that use the same data and may interact with each other, the teams must agree on a shared split_file, or alternatively have an external coordinator that specifies this split_file for them. The interaction of models can take the form of a sequential chain, where the output of one model is the input of another, or of models that work in parallel, each on a sub-task. Example: an image taken by an autonomous car, where one model segments the objects that are cars, and another model classifies each car into its car model.</p>

<p><strong>How to make sure the ml-teams do not ‘cheat’?</strong> A team can cheat by overfitting to the test_set, and this is difficult for the organization to detect. The only way to fully prevent it is to completely hide the global test_set from the teams. However, this means the teams have less data to work with.</p>

<p><strong>Save split inside model</strong>: When we are given a published model file (architecture+weights), how can we tell which examples we are allowed to evaluate it on, in case it comes from a different ml-team?
When saving a model to disk, save the set of filenames/pointers of the examples used for training, or the set of hashes of the examples/features used for training, as part of the model state: model.split_context (register_buffer).
When loading a model from disk, make sure the test_set you plan to evaluate the model on and model.split_context <strong>do not</strong> overlap. This way, you can be more confident that your evaluation is solid and correct.
Optionally, if model.split_context is smaller than the current global split’s training set, you can print a warning that the model could be retrained with more training data.</p>
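<p>A framework-agnostic sketch of the idea (function names are made up; in PyTorch you would bundle the same information into the checkpoint dictionary passed to <code>torch.save</code>, or into a <code>register_buffer</code> tensor of hashes):</p>

```python
import pickle

def save_with_split_context(weights, train_example_hashes, path):
    # Persist the training-set membership alongside the weights, so any
    # downstream team can later verify that an evaluation is valid.
    with open(path, "wb") as f:
        pickle.dump({"weights": weights,
                     "split_context": sorted(train_example_hashes)}, f)

def load_and_verify(path, planned_test_hashes):
    # Refuse to hand over the weights if the planned test set overlaps
    # the examples this model was trained on.
    with open(path, "rb") as f:
        ckpt = pickle.load(f)
    overlap = set(ckpt["split_context"]) & set(planned_test_hashes)
    if overlap:
        raise ValueError(f"{len(overlap)} planned test examples were in the training set")
    return ckpt["weights"]
```

The overlap check runs at load time, so a model published by another team fails loudly rather than silently producing an inflated evaluation.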

<p><strong>What happens when the pre-split dataset (the all_set) grows?</strong> We want the <em>existing</em> models to remain <strong>valid</strong>. That means an old example must <strong>stick</strong> to its previous train/test affiliation. A train example cannot move to the test set, otherwise the model’s measured performance will be inflated. A test example could in principle become a training example, but I don’t see a necessity for that. So when we re-split, we cannot use a fresh random-seed split; we must <strong>preserve</strong> the previous affiliations and, for each new example, assign it to train or test with some predefined probability. As the dataset grows we have the flexibility to change this probability, for example if we want to increase the ratio of training-set size to test-set size.</p>
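<p>A sketch of such a preserving re-split (names are illustrative). Old examples keep their affiliation; only genuinely new examples are assigned, and the assignment is hash-based so repeated runs agree:</p>

```python
import hashlib

def update_split(existing_split, new_example_ids, test_prob=0.2):
    """Extend a train/test split without moving any existing example."""
    split = {"train": list(existing_split["train"]),
             "test": list(existing_split["test"])}
    assigned = set(split["train"]) | set(split["test"])
    for ex_id in new_example_ids:
        if ex_id in assigned:
            continue  # previous affiliation is preserved
        bucket = int(hashlib.md5(ex_id.encode("utf-8")).hexdigest(), 16) % 100
        side = "test" if bucket < test_prob * 100 else "train"
        split[side].append(ex_id)
    return split
```

Because the hash of an id never changes, running the update twice, or on two machines, produces the same extended split.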

<p><strong>Dataset Derivatives:</strong> Sometimes an ml-team needs to derive a dataset from another. For example: the dataset contains images from a car’s front camera. Model A does instance segmentation of cars with 2 classes (background/foreground), and model B uses the 512x512 bounding box to classify the cars into 100 classes. Team B would like to create a dataset of cars, extracted from the root dataset, which makes it a <strong>derivative</strong>. Now, model B stands by itself: it can be evaluated standalone, regardless of model A. However, it can also be evaluated and used in tandem, as part of the full pipeline of finding cars in the image (model A) and classifying them (model B). In both scenarios you want all the safety measures to be in place. This means the derivative’s split_file should include a pointer to the <strong>parent</strong> split_file, and model.split_context should also include the set of training images used in the parent split_file. This allows you to safely evaluate the model both on the parent dataset and on the derivative dataset.</p>
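<p>In split.json terms, a derivative’s split file could simply carry a pointer to its parent (the field name and paths are illustrative):</p>

```json
{
  "parent_split_file": "/data/root_dataset/split.json",
  "train": [{"path": "/data/cars_crops/car_0001.png", "hash": "1a2b3c"}],
  "test":  [{"path": "/data/cars_crops/car_0002.png", "hash": "4d5e6f"}]
}
```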

<p><strong>Model evaluation principles:</strong> Sometimes, when you use data augmentation for training, you may decide to use augmentations for the test_set as well; for example, if your test_set is not big enough and you want to enlarge it, or if you want to make sure the model’s generalization withstands the transformation. In such cases, where the evaluation process involves <strong>randomness</strong>, you need to <strong>set the RNG seed</strong> before the evaluation, so that the evaluation is consistent across different RNG states, whether on the same computer or across computers. However, this can drastically disrupt the training process: when you return from the evaluation procedure, the next training batch would start from the same RNG state every time. Therefore, before setting the evaluation seed, you must save the RNG state, and restore it at the end:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Save the current RNG states, before doing the evaluation, after a few training cycles
</span><span class="n">rng_state_torch</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">get_rng_state</span><span class="p">()</span>
<span class="n">rng_state_cuda</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">get_rng_state_all</span><span class="p">()</span> <span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">()</span> <span class="k">else</span> <span class="bp">None</span>
<span class="n">rng_state_random</span> <span class="o">=</span> <span class="n">random</span><span class="p">.</span><span class="n">getstate</span><span class="p">()</span>
<span class="n">rng_state_numpy</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">get_state</span><span class="p">()</span>
<span class="n">seed_everything</span><span class="p">(</span><span class="n">dice_classifier_train_seed</span><span class="p">)</span>

<span class="n">model</span><span class="p">.</span><span class="nb">eval</span><span class="p">()</span>
<span class="n">eval_loss</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">full_labels</span><span class="p">,</span> <span class="n">full_outputs</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="n">no_grad</span><span class="p">():</span>
   <span class="k">for</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="ow">in</span> <span class="n">test_loader</span><span class="p">:</span>
      <span class="c1"># …
</span>
<span class="c1"># Restore the original RNG states, to allow the training randomness needed
</span><span class="n">torch</span><span class="p">.</span><span class="n">set_rng_state</span><span class="p">(</span><span class="n">rng_state_torch</span><span class="p">)</span>
<span class="k">if</span> <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">is_available</span><span class="p">():</span>
   <span class="n">torch</span><span class="p">.</span><span class="n">cuda</span><span class="p">.</span><span class="n">set_rng_state_all</span><span class="p">(</span><span class="n">rng_state_cuda</span><span class="p">)</span>
<span class="n">random</span><span class="p">.</span><span class="n">setstate</span><span class="p">(</span><span class="n">rng_state_random</span><span class="p">)</span>
<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">set_state</span><span class="p">(</span><span class="n">rng_state_numpy</span><span class="p">)</span>
</code></pre></div></div>

<p>If you want to be more sure about your evaluation results:</p>

<p>(1) As a sanity check, occasionally run the evaluation procedure twice and make sure you get exactly the same results. This helps validate that the RNG seeds were properly set and nothing was forgotten.</p>
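<p>This double-run check is easy to automate (a sketch; <code>evaluate</code> stands in for your own evaluation function):</p>

```python
def assert_deterministic(evaluate, *args, **kwargs):
    """Run the evaluation twice and fail loudly if the results differ,
    which would indicate an RNG seed that was not (re)set properly."""
    first = evaluate(*args, **kwargs)
    second = evaluate(*args, **kwargs)
    if first != second:
        raise RuntimeError(f"Non-deterministic evaluation: {first} != {second}")
    return first
```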

<p>(2) Make sure that the evaluation which happens every few steps during training is identical to the evaluation after freshly loading the model from disk. That is, the last evaluation loop during training, which runs just before the model is saved to disk, should produce results identical to evaluating the loaded model without any training at all.</p>

<p>To help maintain consistency in evaluation, make sure the data loader has <em>persistent_workers=False</em>, so that every evaluation cycle starts from the exact same RNG state, with no leftovers from the previous evaluation cycle:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="k">def</span> <span class="nf">worker_init_fn</span><span class="p">(</span><span class="n">worker_id</span><span class="p">):</span>
	<span class="n">seed</span> <span class="o">=</span> <span class="mi">42</span> <span class="o">+</span> <span class="n">worker_id</span>
	<span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
	<span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
	<span class="n">torch</span><span class="p">.</span><span class="n">manual_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
	
<span class="n">data_loader_test</span> <span class="o">=</span> <span class="n">DataLoader</span><span class="p">(</span><span class="n">dataset_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>
	                              <span class="n">num_workers</span><span class="o">=</span><span class="n">num_workers</span><span class="p">,</span>
	                              <span class="n">persistent_workers</span><span class="o">=</span><span class="bp">False</span><span class="p">,</span>  <span class="c1"># persistent_workers must be False for consistent results!!
</span>								  <span class="n">collate_fn</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="nb">tuple</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">x</span><span class="p">)),</span> <span class="n">worker_init_fn</span><span class="o">=</span><span class="n">worker_init_fn</span><span class="p">,</span> <span class="n">generator</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">Generator</span><span class="p">().</span><span class="n">manual_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">))</span>
								  
</code></pre></div></div>

<p>Rarely, the evaluation procedure contains a .train() part. This can happen when you use model libraries that return certain metrics and results (like the loss) only in .train() mode. If that is the case, and the model uses BatchNorm, you must save the BatchNorm state before the .train() part and restore it afterwards; otherwise the .train() part of your evaluation procedure will have side effects on the model’s training process:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">save_batchnorm_state</span><span class="p">(</span><span class="n">model</span><span class="p">):</span>
	<span class="n">bn_states</span> <span class="o">=</span> <span class="p">[]</span>
	<span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">model</span><span class="p">.</span><span class="n">modules</span><span class="p">():</span>
		<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm2d</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm1d</span><span class="p">)):</span>
			<span class="n">bn_states</span><span class="p">.</span><span class="n">append</span><span class="p">({</span>
				<span class="s">"running_mean"</span><span class="p">:</span> <span class="n">module</span><span class="p">.</span><span class="n">running_mean</span><span class="p">.</span><span class="n">clone</span><span class="p">(),</span>
				<span class="s">"running_var"</span><span class="p">:</span> <span class="n">module</span><span class="p">.</span><span class="n">running_var</span><span class="p">.</span><span class="n">clone</span><span class="p">(),</span>
				<span class="s">"num_batches_tracked"</span><span class="p">:</span> <span class="n">module</span><span class="p">.</span><span class="n">num_batches_tracked</span><span class="p">.</span><span class="n">clone</span><span class="p">()</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="s">"num_batches_tracked"</span><span class="p">)</span> <span class="k">else</span> <span class="bp">None</span>
			<span class="p">})</span>
	<span class="k">return</span> <span class="n">bn_states</span>


<span class="c1"># Function to restore BatchNorm states
</span><span class="k">def</span> <span class="nf">restore_batchnorm_state</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">bn_states</span><span class="p">):</span>
	<span class="n">i</span> <span class="o">=</span> <span class="mi">0</span>
	<span class="k">for</span> <span class="n">module</span> <span class="ow">in</span> <span class="n">model</span><span class="p">.</span><span class="n">modules</span><span class="p">():</span>
		<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">module</span><span class="p">,</span> <span class="p">(</span><span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm2d</span><span class="p">,</span> <span class="n">nn</span><span class="p">.</span><span class="n">BatchNorm1d</span><span class="p">)):</span>
			<span class="n">module</span><span class="p">.</span><span class="n">running_mean</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">bn_states</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="s">"running_mean"</span><span class="p">])</span>
			<span class="n">module</span><span class="p">.</span><span class="n">running_var</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">bn_states</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="s">"running_var"</span><span class="p">])</span>
			<span class="k">if</span> <span class="n">bn_states</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="s">"num_batches_tracked"</span><span class="p">]</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
				<span class="n">module</span><span class="p">.</span><span class="n">num_batches_tracked</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">bn_states</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="s">"num_batches_tracked"</span><span class="p">])</span>
			<span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span>
			

</code></pre></div></div>

<h1 id="the-full-code">The full code</h1>

<p>Here is a class I created, DatasetSplitter, that implements the concepts specified above:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">os</span><span class="p">,</span> <span class="n">torch</span><span class="p">,</span> <span class="n">json</span><span class="p">,</span> <span class="n">hashlib</span><span class="p">,</span> <span class="n">random</span>
<span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="n">Path</span>
<span class="k">class</span> <span class="nc">DatasetSplitter</span><span class="p">:</span>
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">_calculate_file_hash</span><span class="p">(</span><span class="n">file_path</span><span class="p">):</span>
		<span class="s">"""Calculates a 6-character alphanumeric hash for a given file."""</span>
		<span class="n">hasher</span> <span class="o">=</span> <span class="n">hashlib</span><span class="p">.</span><span class="n">md5</span><span class="p">()</span>
		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s">'rb'</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
			<span class="n">buf</span> <span class="o">=</span> <span class="n">f</span><span class="p">.</span><span class="n">read</span><span class="p">()</span>
			<span class="n">hasher</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">buf</span><span class="p">)</span>
		<span class="k">return</span> <span class="n">hasher</span><span class="p">.</span><span class="n">hexdigest</span><span class="p">()[:</span><span class="mi">6</span><span class="p">]</span>
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">validation</span><span class="p">(</span><span class="n">all_files</span><span class="p">,</span> <span class="n">split_file_path</span><span class="p">):</span>
		<span class="s">"""
		Validate that the split file at split_file_path does not contain the same entry in both train and test.
		Validate that all_files matches the files in the split file.
		"""</span>
		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_file_path</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">split_data_loaded</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="nb">file</span><span class="p">)</span>
		
		<span class="n">train_set</span><span class="p">,</span> <span class="n">test_set</span> <span class="o">=</span> <span class="n">split_data_loaded</span><span class="p">[</span><span class="s">'train'</span><span class="p">],</span> <span class="n">split_data_loaded</span><span class="p">[</span><span class="s">'test'</span><span class="p">]</span>
		<span class="n">all_files_in_split</span> <span class="o">=</span> <span class="p">{</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span> <span class="o">+</span> <span class="n">test_set</span><span class="p">}</span>
		
		<span class="k">def</span> <span class="nf">find_duplicates</span><span class="p">(</span><span class="n">input_list</span><span class="p">,</span> <span class="n">error_message</span><span class="p">):</span>
			<span class="n">seen</span><span class="p">,</span> <span class="n">duplicates</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(),</span> <span class="nb">set</span><span class="p">()</span>
			<span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">input_list</span><span class="p">:</span>
				<span class="k">if</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">seen</span><span class="p">:</span> <span class="n">duplicates</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">item</span><span class="p">)</span>  <span class="c1"># Add to duplicates if already seen
</span>				<span class="k">else</span><span class="p">:</span> <span class="n">seen</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">item</span><span class="p">)</span>  <span class="c1"># Mark as seen
</span>			<span class="k">if</span> <span class="n">duplicates</span><span class="p">:</span>
				<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">error_message</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="s">', '</span><span class="p">.</span><span class="n">join</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="nb">str</span><span class="p">,</span> <span class="n">duplicates</span><span class="p">))</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
				
		<span class="c1"># validations
</span>		<span class="n">find_duplicates</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span><span class="p">],</span> <span class="s">"Train set contains identical path"</span><span class="p">)</span>
		<span class="n">find_duplicates</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span><span class="p">],</span> <span class="s">"Train set contains identical hash"</span><span class="p">)</span>
		<span class="n">find_duplicates</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">test_set</span><span class="p">],</span> <span class="s">"Test set contains identical path"</span><span class="p">)</span>
		<span class="n">find_duplicates</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">test_set</span><span class="p">],</span> <span class="s">"Test set contains identical hash"</span><span class="p">)</span>
		<span class="n">find_duplicates</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span><span class="p">]</span><span class="o">+</span><span class="p">[</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">test_set</span><span class="p">],</span> <span class="s">"Train and test sets overlap in filename path"</span><span class="p">)</span>
		<span class="n">find_duplicates</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span><span class="p">]</span> <span class="o">+</span> <span class="p">[</span><span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">test_set</span><span class="p">],</span> <span class="s">"Train and test sets overlap in hash"</span><span class="p">)</span>
		
		
		<span class="c1"># Ensure all files exist and have matching hashes
</span>		<span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span> <span class="o">+</span> <span class="n">test_set</span><span class="p">:</span>
			<span class="n">file_path</span><span class="p">,</span> <span class="n">file_hash</span> <span class="o">=</span> <span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">],</span> <span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span>
			<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">file_path</span><span class="p">):</span>
				<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"File </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> listed in split file does not exist!"</span><span class="p">)</span>
			<span class="n">actual_hash</span> <span class="o">=</span> <span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">_calculate_file_hash</span><span class="p">(</span><span class="n">file_path</span><span class="p">)</span>
			<span class="k">if</span> <span class="n">actual_hash</span> <span class="o">!=</span> <span class="n">file_hash</span><span class="p">:</span>
				<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"Hash mismatch for file </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s">: expected </span><span class="si">{</span><span class="n">file_hash</span><span class="si">}</span><span class="s">, got </span><span class="si">{</span><span class="n">actual_hash</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
		
		<span class="c1"># Ensure no extra files in dataset folder
</span>		<span class="n">actual_files_in_folder</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">all_files</span><span class="p">)</span>
		<span class="k">if</span> <span class="n">extra_files</span> <span class="p">:</span><span class="o">=</span> <span class="n">actual_files_in_folder</span> <span class="o">-</span> <span class="n">all_files_in_split</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">split_file_path</span><span class="p">]):</span>
			<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"DataSplitter.validation Warning: The following files are in the folder but not in the split file: </span><span class="si">{</span><span class="n">extra_files</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">folder_to_list_of_files</span><span class="p">(</span><span class="n">dataset_folder</span><span class="p">):</span>
		<span class="k">return</span> <span class="p">[</span>
			<span class="nb">str</span><span class="p">(</span><span class="nb">file</span><span class="p">.</span><span class="n">resolve</span><span class="p">())</span>
			<span class="k">for</span> <span class="nb">file</span> <span class="ow">in</span> <span class="n">Path</span><span class="p">(</span><span class="n">dataset_folder</span><span class="p">).</span><span class="n">glob</span><span class="p">(</span><span class="s">'**/*'</span><span class="p">)</span>
			<span class="k">if</span> <span class="nb">file</span><span class="p">.</span><span class="n">is_file</span><span class="p">()</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">file</span><span class="p">.</span><span class="n">name</span><span class="p">.</span><span class="n">endswith</span><span class="p">(</span><span class="s">'.json'</span><span class="p">)</span>
		<span class="p">]</span>
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">update</span><span class="p">(</span><span class="n">list_of_files</span><span class="p">,</span> <span class="n">split_file_path</span><span class="p">,</span> <span class="n">train_size</span><span class="p">):</span>
		
		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_file_path</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">split_data</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="nb">file</span><span class="p">)</span>
		
		<span class="n">train_set</span> <span class="o">=</span> <span class="n">split_data</span><span class="p">[</span><span class="s">'train'</span><span class="p">]</span>
		<span class="n">test_set</span> <span class="o">=</span> <span class="n">split_data</span><span class="p">[</span><span class="s">'test'</span><span class="p">]</span>
		
		<span class="n">set_train_hash</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span><span class="p">])</span>
		<span class="n">set_train_path</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span><span class="p">])</span>
		<span class="n">set_test_hash</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'hash'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">test_set</span><span class="p">])</span>
		<span class="n">set_test_path</span> <span class="o">=</span> <span class="nb">set</span><span class="p">([</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">test_set</span><span class="p">])</span>
		
		<span class="n">all_files_in_split</span> <span class="o">=</span> <span class="p">{</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">train_set</span> <span class="o">+</span> <span class="n">test_set</span><span class="p">}</span>
		
		<span class="c1"># Add new files to train or test sets
</span>		<span class="n">actual_files_in_folder</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">list_of_files</span><span class="p">)</span>
		<span class="n">new_files</span> <span class="o">=</span> <span class="n">actual_files_in_folder</span> <span class="o">-</span> <span class="n">all_files_in_split</span> <span class="o">-</span> <span class="nb">set</span><span class="p">([</span><span class="n">split_file_path</span><span class="p">])</span>
		
		<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Found </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">new_files</span><span class="p">)</span><span class="si">}</span><span class="s"> files in the folder which are not in the train set or the test set."</span><span class="p">)</span>
		
		<span class="k">for</span> <span class="n">file_path</span> <span class="ow">in</span> <span class="n">new_files</span><span class="p">:</span>
			<span class="n">file_hash</span> <span class="o">=</span> <span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">_calculate_file_hash</span><span class="p">(</span><span class="n">file_path</span><span class="p">)</span>
			
			<span class="c1"># first, make sure the path and hash of the file are not in train_set or test_set
</span>			<span class="k">if</span> <span class="n">file_path</span> <span class="ow">in</span> <span class="n">set_train_path</span><span class="p">:</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"WARNING, new file </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> is already in train_set, skipping."</span><span class="p">)</span>
				<span class="k">continue</span>
			<span class="k">if</span> <span class="n">file_path</span> <span class="ow">in</span> <span class="n">set_test_path</span><span class="p">:</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"WARNING, new file </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> is already in test_set, skipping."</span><span class="p">)</span>
				<span class="k">continue</span>
			<span class="k">if</span> <span class="n">file_hash</span> <span class="ow">in</span> <span class="n">set_train_hash</span><span class="p">:</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"WARNING, hash </span><span class="si">{</span><span class="n">file_hash</span><span class="si">}</span><span class="s"> of new file </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> is already in train_set, skipping."</span><span class="p">)</span>
				<span class="k">continue</span>
			<span class="k">if</span> <span class="n">file_hash</span> <span class="ow">in</span> <span class="n">set_test_hash</span><span class="p">:</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"WARNING, hash </span><span class="si">{</span><span class="n">file_hash</span><span class="si">}</span><span class="s"> of new file </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> is already in test_set, skipping."</span><span class="p">)</span>
				<span class="k">continue</span>
			
			<span class="k">if</span> <span class="n">random</span><span class="p">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">train_size</span><span class="p">:</span>
				<span class="n">train_set</span><span class="p">.</span><span class="n">append</span><span class="p">({</span><span class="s">"path"</span><span class="p">:</span> <span class="n">file_path</span><span class="p">,</span> <span class="s">"hash"</span><span class="p">:</span> <span class="n">file_hash</span><span class="p">})</span>
				<span class="n">set_train_path</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">file_path</span><span class="p">)</span>
				<span class="n">set_train_hash</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">file_hash</span><span class="p">)</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Added </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> to train set."</span><span class="p">)</span>
			<span class="k">else</span><span class="p">:</span>
				<span class="n">test_set</span><span class="p">.</span><span class="n">append</span><span class="p">({</span><span class="s">"path"</span><span class="p">:</span> <span class="n">file_path</span><span class="p">,</span> <span class="s">"hash"</span><span class="p">:</span> <span class="n">file_hash</span><span class="p">})</span>
				<span class="n">set_test_path</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">file_path</span><span class="p">)</span>
				<span class="n">set_test_hash</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">file_hash</span><span class="p">)</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Added </span><span class="si">{</span><span class="n">file_path</span><span class="si">}</span><span class="s"> to test set."</span><span class="p">)</span>
		
		<span class="n">split_data</span><span class="p">[</span><span class="s">'train'</span><span class="p">]</span> <span class="o">=</span> <span class="n">train_set</span>
		<span class="n">split_data</span><span class="p">[</span><span class="s">'test'</span><span class="p">]</span> <span class="o">=</span> <span class="n">test_set</span>
		
		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_file_path</span><span class="p">,</span> <span class="s">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">json</span><span class="p">.</span><span class="n">dump</span><span class="p">(</span><span class="n">split_data</span><span class="p">,</span> <span class="nb">file</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">create_or_update_root_split_file</span><span class="p">(</span><span class="n">all_files</span><span class="p">,</span> <span class="n">split_file_path</span><span class="p">,</span> <span class="n">train_size</span><span class="p">):</span>
		<span class="s">"""
		If split_file_path does not exist, this scans all the files in the dataset folder, shuffles the list, and splits them into
		train and test sets according to the train_size portion. Then it creates the split_file_path JSON file and saves the two sets
		in it: the train set and the test set. All the file names in the split file use absolute paths.
		
		If the split_file_path already exists, it starts by running a separate validation() function: read the file and check that the
		train set and the test set do not overlap; if they do, it raises an exception. It also makes sure all the files in the train and
		test sets are in the dataset folder; if they are not, it raises an exception. It also checks that there are no additional files
		in the folder that are not in the split file; if there are, it outputs a warning that the split file can be updated.
		Then it runs the update() function: any NEW file in the folder that is not already in the split file (in either train or test)
		is randomly assigned to the train set with probability train_size, or to the test set with probability 1 - train_size, and a
		message stating the new file's affiliation is printed. Finally, it saves the updated split file.
		
		In general, every entry in the split file includes, besides the full path of the file, a 6-character alphanumeric hash of the
		file content, so that if the filename is changed in the future, its signature will be preserved. The validation() function also
		makes sure each hash in the split file matches the true hash of the corresponding file in the folder.
		
		The split file also includes an additional field, "parent", which is the path to the parent split file if one exists, and null otherwise.
		"""</span>
		<span class="k">if</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">split_file_path</span><span class="p">):</span>
			<span class="c1"># Validate and update existing split file
</span>			<span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">validation</span><span class="p">(</span><span class="n">all_files</span><span class="p">,</span> <span class="n">split_file_path</span><span class="p">)</span>
			<span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">update</span><span class="p">(</span><span class="n">all_files</span><span class="p">,</span> <span class="n">split_file_path</span><span class="p">,</span> <span class="n">train_size</span><span class="p">)</span>
		<span class="k">else</span><span class="p">:</span>
			
			<span class="n">random</span><span class="p">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">all_files</span><span class="p">)</span>
			<span class="n">split_point</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">all_files</span><span class="p">)</span> <span class="o">*</span> <span class="n">train_size</span><span class="p">)</span>
			<span class="n">train_files</span> <span class="o">=</span> <span class="n">all_files</span><span class="p">[:</span><span class="n">split_point</span><span class="p">]</span>
			<span class="n">test_files</span> <span class="o">=</span> <span class="n">all_files</span><span class="p">[</span><span class="n">split_point</span><span class="p">:]</span>
			
			<span class="n">split_data</span> <span class="o">=</span> <span class="p">{</span>
				<span class="s">"parent"</span><span class="p">:</span> <span class="bp">None</span><span class="p">,</span>
				<span class="s">"train"</span><span class="p">:</span> <span class="p">[{</span><span class="s">"path"</span><span class="p">:</span> <span class="nb">file</span><span class="p">,</span> <span class="s">"hash"</span><span class="p">:</span> <span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">_calculate_file_hash</span><span class="p">(</span><span class="nb">file</span><span class="p">)}</span> <span class="k">for</span> <span class="nb">file</span> <span class="ow">in</span> <span class="n">train_files</span><span class="p">],</span>
				<span class="s">"test"</span><span class="p">:</span> <span class="p">[{</span><span class="s">"path"</span><span class="p">:</span> <span class="nb">file</span><span class="p">,</span> <span class="s">"hash"</span><span class="p">:</span> <span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">_calculate_file_hash</span><span class="p">(</span><span class="nb">file</span><span class="p">)}</span> <span class="k">for</span> <span class="nb">file</span> <span class="ow">in</span> <span class="n">test_files</span><span class="p">],</span>
			<span class="p">}</span>
			
			<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_file_path</span><span class="p">,</span> <span class="s">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span>
				<span class="n">json</span><span class="p">.</span><span class="n">dump</span><span class="p">(</span><span class="n">split_data</span><span class="p">,</span> <span class="nb">file</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
				<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Created new split file at </span><span class="si">{</span><span class="n">split_file_path</span><span class="si">}</span><span class="s">."</span><span class="p">)</span>
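	
	<span class="c1"># Usage sketch (the paths 'data/' and 'data/split.json' below are hypothetical examples;</span>
	<span class="c1"># 80% of the files go to the train set):</span>
	<span class="c1">#   all_files = DatasetSplitter.folder_to_list_of_files('data')</span>
	<span class="c1">#   DatasetSplitter.create_or_update_root_split_file(all_files, 'data/split.json', train_size=0.8)</span>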
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">create_split_file_from_splitted_lists</span><span class="p">(</span><span class="n">list_files_train</span><span class="p">,</span> <span class="n">list_files_test</span><span class="p">,</span> <span class="n">split_file_path</span><span class="p">,</span> <span class="n">parent_split_file</span><span class="p">):</span>
		<span class="s">"""
		If the file split_file_path does not exist, this creates it from the two lists, after making sure the lists do not overlap.
		If parent_split_file is not None, it runs the validation() function on the parent and on all of its ancestors.
		"""</span>
		
		<span class="c1"># Ensure no overlap between train and test sets
</span>		<span class="n">train_set</span><span class="p">,</span> <span class="n">test_set</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">list_files_train</span><span class="p">),</span> <span class="nb">set</span><span class="p">(</span><span class="n">list_files_test</span><span class="p">)</span>
		<span class="k">if</span> <span class="n">train_set</span> <span class="o">&amp;</span> <span class="n">test_set</span><span class="p">:</span> <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="s">"Train and test sets overlap!"</span><span class="p">)</span>
		
		<span class="c1"># Validate the parent split file if provided
</span>		<span class="k">if</span> <span class="n">parent_split_file</span><span class="p">:</span>
			<span class="n">current_parent</span> <span class="o">=</span> <span class="n">parent_split_file</span>
			<span class="k">while</span> <span class="n">current_parent</span><span class="p">:</span>
				<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">current_parent</span><span class="p">):</span> <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"Parent split file not found: </span><span class="si">{</span><span class="n">current_parent</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
				
				<span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">validation</span><span class="p">(</span><span class="n">all_files</span><span class="o">=</span><span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">folder_to_list_of_files</span><span class="p">(</span><span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">current_parent</span><span class="p">)),</span>
				                           <span class="n">split_file_path</span><span class="o">=</span><span class="n">current_parent</span><span class="p">)</span>
				
				<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">current_parent</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="n">parent_file</span><span class="p">:</span> <span class="n">parent_data</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="n">parent_file</span><span class="p">)</span>
				<span class="n">current_parent</span> <span class="o">=</span> <span class="n">parent_data</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"parent"</span><span class="p">)</span>
				
		<span class="n">split_data</span> <span class="o">=</span> <span class="p">{</span>  <span class="c1"># Create the split file content
</span>			<span class="s">"parent"</span><span class="p">:</span> <span class="n">parent_split_file</span><span class="p">,</span>
			<span class="s">"train"</span><span class="p">:</span> <span class="p">[{</span><span class="s">"path"</span><span class="p">:</span> <span class="nb">file</span><span class="p">,</span> <span class="s">"hash"</span><span class="p">:</span> <span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">_calculate_file_hash</span><span class="p">(</span><span class="nb">file</span><span class="p">)}</span> <span class="k">for</span> <span class="nb">file</span> <span class="ow">in</span> <span class="n">list_files_train</span><span class="p">],</span>
			<span class="s">"test"</span><span class="p">:</span> <span class="p">[{</span><span class="s">"path"</span><span class="p">:</span> <span class="nb">file</span><span class="p">,</span> <span class="s">"hash"</span><span class="p">:</span> <span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">_calculate_file_hash</span><span class="p">(</span><span class="nb">file</span><span class="p">)}</span> <span class="k">for</span> <span class="nb">file</span> <span class="ow">in</span> <span class="n">list_files_test</span><span class="p">],</span>
		<span class="p">}</span>
		<span class="c1"># Save the split file
</span>		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_file_path</span><span class="p">,</span> <span class="s">'w'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">json</span><span class="p">.</span><span class="n">dump</span><span class="p">(</span><span class="n">split_data</span><span class="p">,</span> <span class="nb">file</span><span class="p">,</span> <span class="n">indent</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>
		<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"Split file created at </span><span class="si">{</span><span class="n">split_file_path</span><span class="si">}</span><span class="s">."</span><span class="p">)</span>
		
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">add_split_context_to_model_before_save</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">,</span> <span class="n">model</span><span class="p">):</span>
		<span class="s">"""
		This adds a field/buffer (not a Parameter) to a model, called split_context, which is a list. The first element of the list
		is the content of the JSON object in the split file. If the split file has a parent, the parent's content is included as the
		second element, and so on. This field allows users of the model to make sure they do not evaluate the model on an example
		that is included in the training set of the split file or of any of its ancestors.
		
		This is an example of how you should use it:
		
		DatasetSplitter.add_split_context_to_model_before_save(split_filepath, model)
		torch.save(model.state_dict(), model_save_path)
		"""</span>
		<span class="n">split_context</span> <span class="o">=</span> <span class="p">[]</span>
		
		<span class="c1"># Traverse the split file hierarchy
</span>		<span class="n">current_split_filepath</span> <span class="o">=</span> <span class="n">split_filepath</span>
		<span class="k">while</span> <span class="n">current_split_filepath</span><span class="p">:</span>
			<span class="c1"># Load the current split file
</span>			<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">current_split_filepath</span><span class="p">):</span>
				<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"Split file not found: </span><span class="si">{</span><span class="n">current_split_filepath</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
			<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">current_split_filepath</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">split_data</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="nb">file</span><span class="p">)</span>
			<span class="n">split_context</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">split_data</span><span class="p">)</span>
			
			<span class="c1"># Move to the parent split file, if it exists
</span>			<span class="n">current_split_filepath</span> <span class="o">=</span> <span class="n">split_data</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"parent"</span><span class="p">)</span>
		
		<span class="c1"># Serialize split_context as JSON and register as a tensor buffer
</span>		<span class="n">serialized_context</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">dumps</span><span class="p">(</span><span class="n">split_context</span><span class="p">)</span>
		<span class="n">context_tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">tensor</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">serialized_context</span><span class="p">.</span><span class="n">encode</span><span class="p">()),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">uint8</span><span class="p">)</span>
		<span class="n">model</span><span class="p">.</span><span class="n">register_buffer</span><span class="p">(</span><span class="s">"split_context"</span><span class="p">,</span> <span class="n">context_tensor</span><span class="p">)</span>
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">get_list_of_train_files_and_test_files</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">,</span> <span class="n">compare_to_this_total_list</span><span class="p">):</span>
		<span class="s">"""
		Load the split file, and return the list of train files and test files.
		If compare_to_this_total_list is not None, it validates that the union of train_files and test_files equals compare_to_this_total_list.
		"""</span>
		<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">):</span> <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"Split file not found: </span><span class="si">{</span><span class="n">split_filepath</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">split_data</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="nb">file</span><span class="p">)</span>
		<span class="n">train_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">split_data</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">'train'</span><span class="p">,</span> <span class="p">[])]</span>
		<span class="n">test_files</span> <span class="o">=</span> <span class="p">[</span><span class="n">entry</span><span class="p">[</span><span class="s">'path'</span><span class="p">]</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">split_data</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">'test'</span><span class="p">,</span> <span class="p">[])]</span>
		
		<span class="k">if</span> <span class="n">compare_to_this_total_list</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
			<span class="n">all_files</span> <span class="o">=</span> <span class="n">train_files</span> <span class="o">+</span> <span class="n">test_files</span>
			<span class="c1"># Compare the two sets
</span>			<span class="k">if</span> <span class="nb">set</span><span class="p">(</span><span class="n">all_files</span><span class="p">)</span> <span class="o">!=</span> <span class="nb">set</span><span class="p">(</span><span class="n">compare_to_this_total_list</span><span class="p">):</span>
				<span class="n">missing_from_split</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">compare_to_this_total_list</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">all_files</span><span class="p">)</span>  <span class="c1"># Files in folder but not in split file
</span>				<span class="n">extra_in_split</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">all_files</span><span class="p">)</span> <span class="o">-</span> <span class="nb">set</span><span class="p">(</span><span class="n">compare_to_this_total_list</span><span class="p">)</span>  <span class="c1"># Files in split file but not in folder
</span>				<span class="k">raise</span> <span class="nb">Exception</span><span class="p">(</span>
					<span class="sa">f</span><span class="s">"Split file [</span><span class="si">{</span><span class="n">split_filepath</span><span class="si">}</span><span class="s">] can be updated. Differences:</span><span class="se">\n</span><span class="s">"</span>
					<span class="sa">f</span><span class="s">"Missing from split file: </span><span class="si">{</span><span class="n">missing_from_split</span><span class="si">}</span><span class="se">\n</span><span class="s">"</span>
					<span class="sa">f</span><span class="s">"Extra in split file: </span><span class="si">{</span><span class="n">extra_in_split</span><span class="si">}</span><span class="s">"</span>
				<span class="p">)</span>
		
		<span class="k">return</span> <span class="n">train_files</span><span class="p">,</span> <span class="n">test_files</span>
	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">validate_model_after_load</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">,</span> <span class="n">loaded_model</span><span class="p">):</span>
		<span class="s">"""
		Ensure that the test set described in split_filepath, or any of its ancestors (their union), does not overlap
		with the loaded_model.split_context content, with respect to both the filenames and the hashes.
		This is an example of how you should load a model:
		
		loaded_state_dict = torch.load(model_path, map_location=device, weights_only=True)
		model.register_buffer("split_context", loaded_state_dict["split_context"])
		model.load_state_dict(loaded_state_dict)
		DatasetSplitter.validate_model_after_load(split_filepath, model)
		
		"""</span>
		<span class="c1"># Deserialize split_context from the model
</span>		<span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">,</span> <span class="s">"split_context"</span><span class="p">):</span> <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="s">"Loaded model does not have a `split_context` attribute."</span><span class="p">)</span>
		
		<span class="c1"># Deserialize the split_context tensor into a Python object
</span>		<span class="n">serialized_context</span> <span class="o">=</span> <span class="nb">bytes</span><span class="p">(</span><span class="n">loaded_model</span><span class="p">.</span><span class="n">split_context</span><span class="p">.</span><span class="n">tolist</span><span class="p">()).</span><span class="n">decode</span><span class="p">()</span>  <span class="c1"># Convert tensor to bytes, then decode
</span>		<span class="n">split_context</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">loads</span><span class="p">(</span><span class="n">serialized_context</span><span class="p">)</span>  <span class="c1"># Deserialize JSON string back to a list of dictionaries
</span>		<span class="k">del</span> <span class="n">loaded_model</span>
		
		<span class="c1"># Load the split file and its ancestors into a unified test set
</span>		<span class="n">current_split_filepath</span> <span class="o">=</span> <span class="n">split_filepath</span>
		<span class="n">all_test_files</span><span class="p">,</span> <span class="n">all_test_hashes</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(),</span> <span class="nb">set</span><span class="p">()</span>
		
		<span class="k">while</span> <span class="n">current_split_filepath</span><span class="p">:</span>
			<span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="p">.</span><span class="n">path</span><span class="p">.</span><span class="n">exists</span><span class="p">(</span><span class="n">current_split_filepath</span><span class="p">):</span>
				<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"Split file not found: </span><span class="si">{</span><span class="n">current_split_filepath</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
			<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">current_split_filepath</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span> <span class="n">split_data</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="nb">file</span><span class="p">)</span>
			<span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">split_data</span><span class="p">[</span><span class="s">'test'</span><span class="p">]:</span>
				<span class="n">all_test_files</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">Path</span><span class="p">(</span><span class="n">entry</span><span class="p">[</span><span class="s">"path"</span><span class="p">]).</span><span class="n">name</span><span class="p">)</span>
				<span class="n">all_test_hashes</span><span class="p">.</span><span class="n">add</span><span class="p">(</span><span class="n">entry</span><span class="p">[</span><span class="s">"hash"</span><span class="p">])</span>
			<span class="n">current_split_filepath</span> <span class="o">=</span> <span class="n">split_data</span><span class="p">.</span><span class="n">get</span><span class="p">(</span><span class="s">"parent"</span><span class="p">)</span>
		
		<span class="k">for</span> <span class="n">context</span> <span class="ow">in</span> <span class="n">split_context</span><span class="p">:</span>  <span class="c1"># Check for overlap between the test set and the model's split_context
</span>			<span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">context</span><span class="p">[</span><span class="s">"train"</span><span class="p">]:</span>  <span class="c1"># Validate against the training set in the context
</span>				<span class="n">filename</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">entry</span><span class="p">[</span><span class="s">"path"</span><span class="p">]).</span><span class="n">name</span>
				<span class="n">file_hash</span> <span class="o">=</span> <span class="n">entry</span><span class="p">[</span><span class="s">"hash"</span><span class="p">]</span>
				<span class="k">if</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">all_test_files</span><span class="p">:</span> <span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"Filename </span><span class="si">{</span><span class="n">filename</span><span class="si">}</span><span class="s"> in test set overlaps with training set in model context."</span><span class="p">)</span>
				<span class="k">if</span> <span class="n">file_hash</span> <span class="ow">in</span> <span class="n">all_test_hashes</span><span class="p">:</span>
					<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"File hash </span><span class="si">{</span><span class="n">file_hash</span><span class="si">}</span><span class="s"> in test set overlaps with training set in model context."</span><span class="p">)</span>
		
		<span class="c1"># print("Validation passed: No overlap between test set and model's training split_context.")
</span>	
	<span class="o">@</span><span class="nb">staticmethod</span>
	<span class="k">def</span> <span class="nf">helper_split_annotation_file_according_to_splitfile</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">,</span> <span class="n">loaded_annotation_json_file</span><span class="p">):</span>
		<span class="s">"""
		loaded_annotation_json_file contains a list of annotation objects. Each annotation object has a field, as in this example:
		"data": {
	      "image": "\/data\/upload\/3\/ee88667f-13.jpg"
	    }
		Use only the filename from the json file and ignore its path. If the filename is not listed in the split file, raise an
		exception. Otherwise, assign it to the train set or the test set.
		The method returns two lists: the annotations filtered to training files, and to test files (each list keeps the
		same structure as the input, just with filtered elements at the top level).
		"""</span>
		<span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">split_filepath</span><span class="p">,</span> <span class="s">'r'</span><span class="p">)</span> <span class="k">as</span> <span class="nb">file</span><span class="p">:</span>  <span class="n">split_data</span> <span class="o">=</span> <span class="n">json</span><span class="p">.</span><span class="n">load</span><span class="p">(</span><span class="nb">file</span><span class="p">)</span>		<span class="c1"># Load the split file
</span>		
		<span class="n">train_files</span> <span class="o">=</span> <span class="p">{</span><span class="n">Path</span><span class="p">(</span><span class="n">entry</span><span class="p">[</span><span class="s">"path"</span><span class="p">]).</span><span class="n">name</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">split_data</span><span class="p">[</span><span class="s">'train'</span><span class="p">]}</span>
		<span class="n">test_files</span> <span class="o">=</span> <span class="p">{</span><span class="n">Path</span><span class="p">(</span><span class="n">entry</span><span class="p">[</span><span class="s">"path"</span><span class="p">]).</span><span class="n">name</span> <span class="k">for</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">split_data</span><span class="p">[</span><span class="s">'test'</span><span class="p">]}</span>
		<span class="n">all_split_files</span> <span class="o">=</span> <span class="n">train_files</span> <span class="o">|</span> <span class="n">test_files</span>

		<span class="n">train_annotations</span><span class="p">,</span> <span class="n">test_annotations</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
		
		<span class="k">for</span> <span class="n">annotation</span> <span class="ow">in</span> <span class="n">loaded_annotation_json_file</span><span class="p">:</span>
			<span class="c1"># Extract the filename from the annotation
</span>			<span class="n">annotated_image_path</span> <span class="o">=</span> <span class="n">annotation</span><span class="p">[</span><span class="s">"data"</span><span class="p">][</span><span class="s">"image"</span><span class="p">]</span>
			<span class="n">annotated_image_filename</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">annotated_image_path</span><span class="p">).</span><span class="n">name</span>
			
			<span class="c1"># Check if the file is in the split file
</span>			<span class="k">if</span> <span class="n">annotated_image_filename</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">all_split_files</span><span class="p">:</span>
				<span class="k">raise</span> <span class="nb">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s">"File </span><span class="si">{</span><span class="n">annotated_image_filename</span><span class="si">}</span><span class="s"> in annotations is not listed in the split file."</span><span class="p">)</span>
			
			<span class="c1"># Assign to train or test set
</span>			<span class="k">if</span> <span class="n">annotated_image_filename</span> <span class="ow">in</span> <span class="n">train_files</span><span class="p">:</span>
				<span class="n">train_annotations</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">annotation</span><span class="p">)</span>
			<span class="k">elif</span> <span class="n">annotated_image_filename</span> <span class="ow">in</span> <span class="n">test_files</span><span class="p">:</span>
				<span class="n">test_annotations</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">annotation</span><span class="p">)</span>
			<span class="k">else</span><span class="p">:</span> <span class="k">raise</span> <span class="nb">Exception</span><span class="p">(</span><span class="s">"This should never happen."</span><span class="p">)</span>
		
		<span class="k">return</span> <span class="n">train_annotations</span><span class="p">,</span> <span class="n">test_annotations</span>
	
<span class="k">if</span> <span class="n">__name__</span> <span class="o">==</span> <span class="s">'__main__'</span><span class="p">:</span>
	<span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">create_or_update_root_split_file</span><span class="p">(</span>
		<span class="n">all_files</span><span class="o">=</span><span class="n">DatasetSplitter</span><span class="p">.</span><span class="n">folder_to_list_of_files</span><span class="p">(</span><span class="s">'/image_storage/'</span><span class="p">),</span>
		<span class="n">split_file_path</span><span class="o">=</span><span class="s">'split.json'</span><span class="p">,</span>
		<span class="n">train_size</span><span class="o">=</span><span class="mf">0.75</span>
	<span class="p">)</span>
	
</code></pre></div></div>]]></content><author><name>Nadav Benedek</name></author><category term="data" /><category term="leakage," /><category term="data" /><category term="preparation," /><category term="Preprocessing" /><category term="Leakage," /><category term="feature" /><category term="leakage," /><category term="feature" /><category term="engineering" /><summary type="html"><![CDATA[Are we allowed to transform the input data in any way we want? Can we train sub-models to preprocess features? Can we use a pipeline of models? Can we use the output of one model as an input of another model?]]></summary></entry><entry><title type="html">Shares Efficient Frontier</title><link href="https://nadavb.com/Shares-Efficient-Frontier/" rel="alternate" type="text/html" title="Shares Efficient Frontier" /><published>2023-08-05T06:20:00+03:00</published><updated>2023-08-05T06:20:00+03:00</updated><id>https://nadavb.com/Shares%20Efficient%20Frontier</id><content type="html" xml:base="https://nadavb.com/Shares-Efficient-Frontier/"><![CDATA[<p>When we invest in a financial instrument (share, index, etc), we usually mostly care about its average annual yield and its variance, which is also called risk or volatility. Naturally, no one can predict the future, so the only thing we can do is to look at the past and assume that the past will reflect the future. So, if we analyze the past history of an instrument, for let’s say 10 or 20 years, we can measure the average annual yield and the variance, or more specifically the standard deviation (which is the square root of the variance, just so it will have the same units as the yield).</p>
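<p>As a quick illustration (not from the original post), here is a minimal sketch of measuring the mean annual yield and its standard deviation from a hypothetical series of year-end closing prices; the prices are made-up numbers:</p>

```python
# Hypothetical year-end closing prices (made-up numbers)
prices = [100.0, 112.0, 98.0, 120.0, 135.0]

# Annual yields: relative change from one year-end to the next
returns = [(b - a) / a for a, b in zip(prices, prices[1:])]

# Sample mean and sample standard deviation of the annual yields
mean = sum(returns) / len(returns)
stdev = (sum((r - mean) ** 2 for r in returns) / (len(returns) - 1)) ** 0.5
print(round(mean, 3), round(stdev, 3))  # 0.086 0.149
```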

<p>Now, imagine that we plot a chart with two axes: the stdev (risk) axis and the yield axis. It will look like this:</p>

<p><img src="/assets/efficient_frontier/nasdaq-yield-vs-stdev.png" alt="" /></p>

<p>So we can easily observe, for example, that the NASDAQ has a higher yield than the S&amp;P 500, but also higher variance. If an instrument has a better yield and lower variance than a second instrument, we would prefer to invest in the first one. Such an instrument is called Pareto-better.</p>

<p>Okay. So far, so good. But what happens when we invest in a mixture of instruments? Where will the combined coordinate lie? The yield of the mixture is simply the weighted average of the individual yields. But the variance is something else: the stdev can sometimes be even lower than that of <em>each</em> of the instruments. This can happen when the correlation between the instruments is not perfect. In math, when you have two random variables and you average them together, the resulting stdev is:</p>

\[\frac{1}{2} \sqrt{stdev(X)^2+stdev(Y)^2+2Cov(X,Y)}\]

<p>So if $stdev(X)=10=stdev(Y)$ and the variables are <em>independent</em>, meaning their covariance is zero, you get that the combined stdev is $\frac{1}{2}\sqrt{200} \approx 7.07$, which is lower than 10. This is the mathematical grounding for diversifying the risk: you take two risky instruments, invest in both of them, and reduce the risk.</p>
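<p>This arithmetic can be checked directly. The following sketch (using the hypothetical stdev of 10 and zero covariance from the example above) verifies the formula both analytically and against simulated returns:</p>

```python
import numpy as np

# Equal-weight average of two instruments X and Y:
# stdev((X + Y) / 2) = 0.5 * sqrt(Var(X) + Var(Y) + 2 * Cov(X, Y))
std_x, std_y, cov_xy = 10.0, 10.0, 0.0  # hypothetical values from the text
combined_std = 0.5 * np.sqrt(std_x**2 + std_y**2 + 2 * cov_xy)
print(round(combined_std, 2))  # 7.07 -- lower than either instrument's stdev of 10

# Cross-check with simulated independent returns
rng = np.random.default_rng(42)
x = rng.normal(0.0, std_x, 500_000)
y = rng.normal(0.0, std_y, 500_000)
print(round(((x + y) / 2).std(), 1))  # close to 7.1
```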

<p>Here you can see the correlation matrix between some instruments:</p>

<p><img src="/assets/efficient_frontier/nasdaq-correlation-matrix.png" alt="" /></p>

<p>So using this notion, you can invest in a mixture (portfolio) of instruments that will give you pareto-better results, rather than investing in some of the instruments alone, as can be seen here:</p>

<p><img src="/assets/efficient_frontier/nasdaq-efficient-frontier.png" alt="" /></p>

<p>So, you can observe many points that are Pareto-better than the Dow Jones or the S&amp;P 500.</p>
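<p>The mixing step itself can be sketched as follows. This is not the post's own frontier code (which appears below); the mean yields and covariance matrix here are made-up numbers for three hypothetical instruments. Each random weight vector gives one point on the risk/yield chart: the portfolio yield is the weighted average of the means, and the portfolio variance is the quadratic form of the weights with the covariance matrix:</p>

```python
import numpy as np

# Made-up annual mean yields and covariance matrix for three hypothetical instruments
mu = np.array([0.10, 0.07, 0.05])
cov = np.array([[0.040, 0.010, 0.002],
                [0.010, 0.020, 0.001],
                [0.002, 0.001, 0.010]])

rng = np.random.default_rng(0)
points = []
for _ in range(10_000):
    w = rng.dirichlet(np.ones(3))      # random non-negative weights summing to 1
    port_mean = w @ mu                 # portfolio yield: weighted average of means
    port_std = np.sqrt(w @ cov @ w)    # portfolio stdev: sqrt(w^T Sigma w)
    points.append((port_std, port_mean, w))

# The lowest-risk mixture found has a smaller stdev than the safest
# single instrument (sqrt(0.010) = 0.1) -- the diversification effect.
port_std, port_mean, w = min(points, key=lambda t: t[0])
print(round(port_std, 3), round(port_mean, 3))
```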

\[\square\]

<p>Here’s the code:</p>

<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="n">pd</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="n">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>
<span class="kn">import</span> <span class="nn">seaborn</span> <span class="k">as</span> <span class="n">sns</span>
<span class="kn">import</span> <span class="nn">mplcursors</span>  

<span class="n">val_col_name</span> <span class="o">=</span> <span class="s">'Close'</span>
<span class="n">base</span> <span class="o">=</span> <span class="mi">4</span>  

<span class="n">list_shares</span> <span class="o">=</span> <span class="p">[</span>
				<span class="p">(</span><span class="s">'https://gist.githubusercontent.com/ndvbd/30d8069937f945e492bd440a003296c7/raw/a119c81f4fb3d13d4f5b7b03c6cf0f4d6c778cdf/SP500.csv'</span><span class="p">,</span> <span class="s">'SP500'</span><span class="p">),</span>
				<span class="p">(</span><span class="s">'https://gist.githubusercontent.com/ndvbd/2a4516b0f18129287b9de4708f5ce2bf/raw/c69988fc5993699b1bce21c18b4dd1623cb7cb6d/NASDAQ.csv'</span><span class="p">,</span> <span class="s">'NASDAQ'</span><span class="p">),</span>
				<span class="p">(</span><span class="s">'https://gist.githubusercontent.com/ndvbd/01cb8aa365e212041037ca44e1068dba/raw/2dc76c20a3a354a641bfa5e0322adc3bc5dfff77/DOW.csv'</span><span class="p">,</span> <span class="s">'DOW'</span><span class="p">),</span>
					<span class="p">(</span><span class="s">'https://gist.githubusercontent.com/ndvbd/039f3a31ce29c71cbc8433c9c4d0380e/raw/805b94bde56597caf464c885e681f850d90d6243/XLP.csv'</span><span class="p">,</span> <span class="s">'XLP'</span><span class="p">),</span>

<span class="p">]</span>

<span class="n">pandas</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">list_shares</span><span class="p">)):</span>
	<span class="n">read_df</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">list_shares</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">index_col</span><span class="o">=</span><span class="s">'Date'</span><span class="p">,</span> <span class="n">parse_dates</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
	<span class="n">first_date</span> <span class="o">=</span> <span class="n">read_df</span><span class="p">.</span><span class="n">index</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
	<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"first date of </span><span class="si">{</span><span class="n">list_shares</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">: </span><span class="si">{</span><span class="n">first_date</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
	<span class="n">pandas</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">read_df</span><span class="p">)</span>

<span class="n">suffixes</span> <span class="o">=</span> <span class="p">[</span><span class="sa">f</span><span class="s">'_</span><span class="si">{</span><span class="n">name</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">'</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">list_shares</span><span class="p">]</span>

<span class="n">returns</span> <span class="o">=</span> <span class="n">pandas</span><span class="p">[</span><span class="mi">0</span><span class="p">][[</span><span class="n">val_col_name</span><span class="p">]].</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">{</span><span class="n">val_col_name</span><span class="p">:</span> <span class="sa">f</span><span class="s">'</span><span class="si">{</span><span class="n">val_col_name</span><span class="si">}{</span><span class="n">suffixes</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s">'</span><span class="p">})</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">pandas</span><span class="p">)):</span>
	<span class="n">returns</span> <span class="o">=</span> <span class="n">returns</span><span class="p">.</span><span class="n">merge</span><span class="p">(</span>
		<span class="n">pandas</span><span class="p">[</span><span class="n">i</span><span class="p">][[</span><span class="n">val_col_name</span><span class="p">]].</span><span class="n">rename</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">{</span><span class="n">val_col_name</span><span class="p">:</span> <span class="sa">f</span><span class="s">'</span><span class="si">{</span><span class="n">val_col_name</span><span class="si">}{</span><span class="n">suffixes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s">'</span><span class="p">})</span>
		<span class="p">,</span> <span class="n">left_index</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">right_index</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">how</span><span class="o">=</span><span class="s">'inner'</span><span class="p">)</span>

<span class="n">result</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DataFrame</span><span class="p">()</span>
<span class="n">current_date</span> <span class="o">=</span> <span class="n">returns</span><span class="p">.</span><span class="n">index</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>

<span class="k">while</span> <span class="n">current_date</span> <span class="o">&lt;=</span> <span class="n">returns</span><span class="p">.</span><span class="n">index</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>

	<span class="n">sample</span> <span class="o">=</span> <span class="n">returns</span><span class="p">.</span><span class="n">loc</span><span class="p">[</span><span class="n">current_date</span><span class="p">:</span><span class="n">current_date</span><span class="p">]</span>

	<span class="n">result</span> <span class="o">=</span> <span class="n">pd</span><span class="p">.</span><span class="n">concat</span><span class="p">([</span><span class="n">result</span><span class="p">,</span> <span class="n">sample</span><span class="p">])</span>

	<span class="n">current_date</span> <span class="o">+=</span> <span class="n">pd</span><span class="p">.</span><span class="n">DateOffset</span><span class="p">(</span><span class="n">years</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>

	<span class="k">if</span> <span class="n">current_date</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">returns</span><span class="p">.</span><span class="n">index</span><span class="p">:</span>

		<span class="n">future_dates</span> <span class="o">=</span> <span class="n">returns</span><span class="p">.</span><span class="n">index</span><span class="p">[</span><span class="n">returns</span><span class="p">.</span><span class="n">index</span> <span class="o">&gt;=</span> <span class="n">current_date</span><span class="p">]</span>
		<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">future_dates</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
			<span class="n">current_date</span> <span class="o">=</span> <span class="n">future_dates</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
		<span class="k">else</span><span class="p">:</span>
			<span class="k">break</span>

<span class="n">result</span> <span class="o">=</span> <span class="n">result</span><span class="p">.</span><span class="n">reset_index</span><span class="p">()</span>
<span class="n">returns</span> <span class="o">=</span> <span class="n">result</span>

<span class="n">change_df</span> <span class="o">=</span> <span class="n">returns</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span>

<span class="n">shares_mean_std</span> <span class="o">=</span> <span class="p">[]</span>


<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">list_shares</span><span class="p">)):</span>

	<span class="n">pct_change</span> <span class="o">=</span> <span class="n">change_df</span><span class="p">[</span><span class="sa">f</span><span class="s">'Close</span><span class="si">{</span><span class="n">suffixes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s">'</span><span class="p">].</span><span class="n">pct_change</span><span class="p">()</span>
	<span class="n">mean</span> <span class="o">=</span> <span class="n">pct_change</span><span class="p">[</span><span class="mi">1</span><span class="p">:].</span><span class="n">mean</span><span class="p">()</span>
	<span class="n">stdev</span> <span class="o">=</span> <span class="n">pct_change</span><span class="p">[</span><span class="mi">1</span><span class="p">:].</span><span class="n">std</span><span class="p">()</span>
	<span class="n">number_years</span> <span class="o">=</span> <span class="n">pct_change</span><span class="p">.</span><span class="n">count</span><span class="p">()</span>
	<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"For </span><span class="si">{</span><span class="n">list_shares</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s"> we have:  mean: </span><span class="si">{</span><span class="n">mean</span><span class="si">:</span><span class="p">.</span><span class="mi">3</span><span class="n">f</span><span class="si">}</span><span class="s">, stdev: </span><span class="si">{</span><span class="n">stdev</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">, number_years: </span><span class="si">{</span><span class="n">number_years</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
	<span class="n">shares_mean_std</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">mean</span><span class="p">,</span> <span class="n">stdev</span><span class="p">))</span>

	<span class="n">change_df</span><span class="p">[</span><span class="sa">f</span><span class="s">'Daily Return</span><span class="si">{</span><span class="n">suffixes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s">'</span><span class="p">]</span> <span class="o">=</span> <span class="n">pct_change</span>

	<span class="n">change_df</span><span class="p">.</span><span class="n">drop</span><span class="p">(</span><span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="sa">f</span><span class="s">'Close</span><span class="si">{</span><span class="n">suffixes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s">'</span><span class="p">],</span> <span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="n">change_df</span><span class="p">.</span><span class="n">dropna</span><span class="p">(</span><span class="n">inplace</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>

<span class="n">to_plot</span> <span class="o">=</span> <span class="p">[]</span>

<span class="k">def</span> <span class="nf">to_arbitrary_base</span><span class="p">(</span><span class="n">number</span><span class="p">,</span> <span class="n">base</span><span class="p">,</span> <span class="n">pad_to</span><span class="p">):</span>
	<span class="n">digits</span> <span class="o">=</span> <span class="p">[]</span>
	<span class="k">while</span> <span class="n">number</span><span class="p">:</span>
		<span class="n">digits</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">number</span> <span class="o">%</span> <span class="n">base</span><span class="p">))</span>
		<span class="n">number</span> <span class="o">//=</span> <span class="n">base</span>

	<span class="n">digits</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">digits</span><span class="p">)</span>
	<span class="n">padded_array</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">pad</span><span class="p">(</span><span class="n">digits</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">pad_to</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">digits</span><span class="p">)),</span> <span class="s">'constant'</span><span class="p">)</span>

	<span class="k">return</span> <span class="n">padded_array</span>

<span class="k">def</span> <span class="nf">get_composed_earning_for_weight</span><span class="p">(</span><span class="n">random_weights</span><span class="p">):</span>

	<span class="n">list_of_gains</span> <span class="o">=</span> <span class="p">[</span><span class="n">random_weights</span><span class="p">]</span>
	<span class="n">current_earning</span> <span class="o">=</span> <span class="n">random_weights</span><span class="p">.</span><span class="n">copy</span><span class="p">()</span>
	<span class="k">for</span> <span class="n">year</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">change_df</span><span class="p">)):</span>
		<span class="n">current_earning</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">change_df</span><span class="p">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">year</span><span class="p">].</span><span class="n">values</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span> <span class="o">*</span> <span class="n">current_earning</span>
		<span class="n">list_of_gains</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">current_earning</span><span class="p">)</span>

	<span class="n">gain_list</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">(</span><span class="n">list_of_gains</span><span class="p">)</span>
	<span class="n">row_sums</span> <span class="o">=</span> <span class="n">gain_list</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
	<span class="k">return</span> <span class="n">row_sums</span>

<span class="n">max_base</span> <span class="o">=</span> <span class="n">base</span> <span class="o">**</span> <span class="nb">len</span><span class="p">(</span><span class="n">list_shares</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"max_base: </span><span class="si">{</span><span class="n">max_base</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_base</span><span class="p">):</span>
	<span class="k">if</span> <span class="bp">False</span><span class="p">:</span>
		<span class="n">random_weights</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">rand</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">list_shares</span><span class="p">))</span>
	<span class="k">else</span><span class="p">:</span>

		<span class="n">random_weights</span> <span class="o">=</span> <span class="n">to_arbitrary_base</span><span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">base</span><span class="o">=</span><span class="n">base</span><span class="p">,</span> <span class="n">pad_to</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">list_shares</span><span class="p">))</span> <span class="o">/</span> <span class="p">(</span><span class="n">base</span><span class="o">-</span><span class="mf">1.0</span><span class="p">)</span>
		<span class="k">if</span> <span class="n">random_weights</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span> <span class="o">==</span> <span class="mf">0.0</span><span class="p">:</span>
			<span class="n">random_weights</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">ones</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">list_shares</span><span class="p">))</span>  

	<span class="n">random_weights</span> <span class="o">=</span> <span class="n">random_weights</span> <span class="o">/</span> <span class="n">random_weights</span><span class="p">.</span><span class="nb">sum</span><span class="p">()</span>

	<span class="n">row_sums</span> <span class="o">=</span> <span class="n">get_composed_earning_for_weight</span><span class="p">(</span><span class="n">random_weights</span><span class="p">)</span>

	<span class="n">percent_increase</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">diff</span><span class="p">(</span><span class="n">row_sums</span><span class="p">)</span> <span class="o">/</span> <span class="n">row_sums</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
	<span class="n">mean</span><span class="p">,</span> <span class="n">std</span> <span class="o">=</span> <span class="n">percent_increase</span><span class="p">.</span><span class="n">mean</span><span class="p">(),</span> <span class="n">percent_increase</span><span class="p">.</span><span class="n">std</span><span class="p">(</span><span class="n">ddof</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>  

	<span class="n">to_plot</span><span class="p">.</span><span class="n">append</span><span class="p">((</span><span class="n">mean</span><span class="p">,</span> <span class="n">std</span><span class="p">,</span> <span class="n">random_weights</span><span class="p">))</span>

<span class="k">if</span> <span class="bp">True</span><span class="p">:</span>
	<span class="n">y_values</span><span class="p">,</span> <span class="n">x_values</span><span class="p">,</span> <span class="n">random_weights</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span> <span class="n">to_plot</span><span class="p">)</span>
	<span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">x_values</span><span class="p">,</span> <span class="n">y_values</span><span class="p">)</span>

	<span class="n">cursor_hover</span> <span class="o">=</span> <span class="n">mplcursors</span><span class="p">.</span><span class="n">cursor</span><span class="p">(</span><span class="n">hover</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
	<span class="o">@</span><span class="n">cursor_hover</span><span class="p">.</span><span class="n">connect</span><span class="p">(</span><span class="s">"add"</span><span class="p">)</span>
	<span class="k">def</span> <span class="nf">on_add</span><span class="p">(</span><span class="n">sel</span><span class="p">):</span>
		<span class="n">index</span> <span class="o">=</span> <span class="n">sel</span><span class="p">.</span><span class="n">index</span>
		<span class="n">sel</span><span class="p">.</span><span class="n">annotation</span><span class="p">.</span><span class="n">set_text</span><span class="p">(</span><span class="sa">f</span><span class="s">"[Y</span><span class="si">{</span><span class="mf">100.0</span><span class="o">*</span><span class="n">y_values</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s">%,</span><span class="si">{</span><span class="mf">100.0</span><span class="o">*</span><span class="n">x_values</span><span class="p">[</span><span class="n">index</span><span class="p">]</span><span class="si">:</span><span class="p">.</span><span class="mi">1</span><span class="n">f</span><span class="si">}</span><span class="s">%]="</span> <span class="o">+</span> <span class="nb">str</span><span class="p">([</span><span class="sa">f</span><span class="s">"</span><span class="si">{</span><span class="n">list_shares</span><span class="p">[</span><span class="n">idx</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s">:</span><span class="si">{</span><span class="n">val</span><span class="si">:</span><span class="p">.</span><span class="mi">2</span><span class="n">f</span><span class="si">}</span><span class="s">"</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">random_weights</span><span class="p">[</span><span class="n">index</span><span class="p">])</span> <span class="p">]))</span>

	<span class="n">cursor_click</span> <span class="o">=</span> <span class="n">mplcursors</span><span class="p">.</span><span class="n">cursor</span><span class="p">()</span>
	<span class="o">@</span><span class="n">cursor_click</span><span class="p">.</span><span class="n">connect</span><span class="p">(</span><span class="s">"add"</span><span class="p">)</span>
	<span class="k">def</span> <span class="nf">on_click</span><span class="p">(</span><span class="n">sel</span><span class="p">):</span>
		<span class="n">index</span> <span class="o">=</span> <span class="n">sel</span><span class="p">.</span><span class="n">index</span>
		<span class="n">weight_vector</span> <span class="o">=</span> <span class="p">[</span><span class="n">val</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">val</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">random_weights</span><span class="p">[</span><span class="n">index</span><span class="p">])]</span>
		<span class="k">print</span><span class="p">(</span><span class="sa">f</span><span class="s">"plotting mixture: </span><span class="si">{</span><span class="n">weight_vector</span><span class="si">}</span><span class="s">"</span><span class="p">)</span>
		<span class="n">row_sums</span> <span class="o">=</span> <span class="n">get_composed_earning_for_weight</span><span class="p">(</span><span class="n">weight_vector</span><span class="p">)</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">()</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">row_sums</span><span class="p">)</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">'Plot of Vector'</span><span class="p">)</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'year'</span><span class="p">)</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'value'</span><span class="p">)</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">xticks</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">row_sums</span><span class="p">)))</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>

	<span class="n">additional_y_values</span><span class="p">,</span> <span class="n">additional_x_values</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">shares_mean_std</span><span class="p">)</span>
	<span class="n">scatter</span> <span class="o">=</span> <span class="n">plt</span><span class="p">.</span><span class="n">scatter</span><span class="p">(</span><span class="n">additional_x_values</span><span class="p">,</span> <span class="n">additional_y_values</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">'Additional Points'</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span>

	<span class="k">for</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">additional_x_values</span><span class="p">,</span> <span class="n">additional_y_values</span><span class="p">,</span> <span class="n">suffixes</span><span class="p">):</span>
		<span class="n">plt</span><span class="p">.</span><span class="n">text</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">label</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">9</span><span class="p">,</span> <span class="n">ha</span><span class="o">=</span><span class="s">'right'</span><span class="p">,</span> <span class="n">color</span><span class="o">=</span><span class="s">'red'</span><span class="p">)</span>

	<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">'stdev'</span><span class="p">)</span>
	<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">'Annual Yield'</span><span class="p">)</span>
	<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">'Scatter Plot of (x, y)'</span><span class="p">)</span>
	<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>
<span class="n">correlation</span> <span class="o">=</span> <span class="n">change_df</span><span class="p">.</span><span class="n">corr</span><span class="p">()</span>
<span class="n">sns</span><span class="p">.</span><span class="n">heatmap</span><span class="p">(</span><span class="n">correlation</span><span class="p">,</span> <span class="n">annot</span><span class="o">=</span><span class="bp">True</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s">'coolwarm'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">title</span><span class="p">(</span><span class="s">'Correlation Matrix between FTSE 100 and S&amp;P 500'</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">show</span><span class="p">()</span>


</code></pre></div></div>]]></content><author><name>Nadav Benedek</name></author><category term="tax," /><category term="shares," /><category term="efficient" /><category term="frontier" /><summary type="html"><![CDATA[When we invest in a financial instrument (share, index, etc), we usually mostly care about its average annual yield and its variance, which is also called risk or volatility. Naturally, no one can predict the future, so the only thing we can do is to look at the past and assume that the past will reflect the future. So, if we analyze the past history of an instrument, for let’s say 10 or 20 years, we can measure the average annual yield and the variance, or more specifically the standard deviation (which is the square root of the variance, just so it will have the same units as the yield).]]></summary></entry><entry><title type="html">How much will we lose if we buy and sell shares too frequently? Optimal strategy for cashing out shares, tax harvesting and more.</title><link href="https://nadavb.com/Tax-Impact-On-Selling-Shares/" rel="alternate" type="text/html" title="How much will we lose if we buy and sell shares too frequently? Optimal strategy for cashing out shares, tax harvesting and more." /><published>2023-08-03T06:20:00+03:00</published><updated>2023-08-03T06:20:00+03:00</updated><id>https://nadavb.com/Tax%20Impact%20On%20Selling%20Shares</id><content type="html" xml:base="https://nadavb.com/Tax-Impact-On-Selling-Shares/"><![CDATA[<p>When we sell a share, we need to pay capital tax.</p>

<p>Assume we sell a share and then immediately buy the same share (or an equivalent-yield one): does this affect our gain? Are many transactions good or bad for the gain? And when we hold a portfolio of shares and need cash, which share should we sell? You can skip the math and jump straight to the final conclusion section.</p>

<h1 id="example-excessive-transaction-in-a-profitable-share">Example: Excessive transaction in a profitable share</h1>

<p>Let’s look at a simple example of a share which is in profit (a profitable share): we bought at \$100 a share that doubles every year, and we have a capital-tax of 25%. Today, after 1 year, we have two options:</p>

<ol>
  <li>Sell at \$200, cash left: \$175 (tax paid: \$25). Buy at \$175, sell after another 1 year at \$350 (tax paid: \$43.75), cash: \$306.25.</li>
  <li>Don’t sell at \$200, only after another 1 year when it reaches \$400. Cash after two years: \$325 (tax paid: \$75).</li>
</ol>

<p>So we can see in this simple example that holding the share for two years, rather than selling and rebuying after one year, is optimal.</p>
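<p>A quick sanity check of the two options in plain Python (a sketch; <code>net_sale</code> is a helper name I introduce here):</p>

```python
TAX = 0.25  # capital-tax rate from the example

def net_sale(sell_price, cost_basis, tax=TAX):
    """After-tax cash from selling a position bought at cost_basis."""
    return sell_price - tax * max(sell_price - cost_basis, 0.0)

# Option 1: sell at $200 after year 1, rebuy at $175, sell at $350 after year 2.
cash_year1 = net_sale(200, 100)                   # 175.0
option1 = net_sale(2 * cash_year1, cash_year1)    # 306.25

# Option 2: hold both years and sell once at $400.
option2 = net_sale(400, 100)                      # 325.0

print(option1, option2)  # 306.25 325.0
```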

<p>Why does this happen? One way to look at it is through the tax we paid: paying more capital-tax means we made more profit. In the second scenario, we paid tax on a price growth of 100-&gt;400=+300. In the first scenario, we had growth of 100-&gt;200=+100 and 175-&gt;350=+175, a total of +275.</p>

<p>If we split the second scenario into two growths, 100-&gt;200 and 200-&gt;400, we can see that the only difference is the run 200-&gt;400 instead of 175-&gt;350.
In the 2nd scenario we kept the deferred tax invested, so the run starts from 200 and the deferred \$25 doubles along with the share, ending the run at 400.
In the 1st scenario the run starts from 175, because we handed the \$25 tax over, and it ends at 350 instead of 400.
In other words, the 2nd scenario ends \$50 higher: \$25 of it is the deferred tax principal (which we pay either way), and \$25 is the extra gain that principal produced. That extra gain is itself taxed at 25%, so we keep 0.75*\$25=\$18.75 more (\$325 vs. \$306.25).</p>

<p>What happens if the share goes down and halves every year? In this case, there’s no tax to pay, and in both scenarios we’re left with \$25 after two years. So the conclusion is that whether the share goes up or down, it’s better to minimize unneeded transactions. Furthermore, in practice there are often additional fees on each buy/sell, which strengthens the conclusion even more.</p>

<h1 id="example-excessive-transaction-in-a-lossy-share-share-with-current-value-lower-than-the-purchase-value">Example: Excessive transaction in a lossy share (share with current value lower than the purchase value)</h1>

<p>We bought a share at \$100, and have a capital-tax of 25%. Today, after 1 year, the price is \$75, and we know it will double in a year. We have two options:</p>

<ol>
  <li>Sell at \$75 cash left: \$75 (no tax paid, and we can keep the capital loss to offset tax in the future). Buy at \$75, sell after another 1 year at \$150. We have a profit of \$75, but we keep the loss from the previous year of \$25, so we only need to pay tax on a profit of \$50 which is \$12.5, cash: \$137.5.</li>
  <li>Don’t sell at \$75, only after another 1 year when it reaches \$150. Cash after two years: \$137.5 (tax paid: \$12.5).</li>
</ol>

<p>So we can see that in the case of a lossy share, it doesn’t matter whether we keep it or sell and buy it again: since no tax event is involved, the two options are equivalent, unlike the case of a profitable share, where selling and rebuying is not a smart move.</p>
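<p>A quick check of the two options (a sketch, assuming the \$25 realized loss can fully offset the later gain):</p>

```python
TAX = 0.25

# Option 1: sell at $75 (realizing a $25 loss), rebuy at $75, sell at $150.
loss_carried = 100 - 75                          # capital loss kept for later
taxable = max((150 - 75) - loss_carried, 0)      # $75 profit minus the $25 offset
option1 = 150 - TAX * taxable                    # 137.5

# Option 2: hold; sell at $150 against the original $100 basis.
option2 = 150 - TAX * (150 - 100)                # 137.5

print(option1, option2)  # 137.5 137.5
```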

<h1 id="the-general-case-of-excessive-transactions-of-a-profitable-share">The general case of excessive transactions of a profitable share</h1>

<p>Define $b$ as the buy-price of the share, $y$ as the yield-multiple per year (e.g. a multiple of 1.1), $n$ as the number of years in our experiment, $s$ as the sell price, $t$ as the capital-tax ratio (e.g. $0.25$), $c$ as the net cash after the sell. Then, the sell price and the net cash we have after the tax reduction are:</p>

\[s = b * y^n\]

\[c = s - (s-b) * t\]

<p>Substituting the first equation into the second, we can write the cash after a single sale:</p>

\[c = by^n - (by^n-b) * t = by^n-by^nt+bt = b(y^n-y^nt+t) = b(y^n(1-t)+t)\]

<p>How much do we lose from frequent/unnecessary intermediate sales?</p>

<p>Define $f$ as the frequency of sales (number of sales during the $n$ years of our experiment), so the effective yield for each sub period of $\frac{n}{f}$ years is:</p>

\[p = (y^n)^{1/f}  = y^{n/f}\]

<p>So the cash after $f$ periods of buy/sell is:</p>

\[b (y^{\frac{n}{f}} (1-t)+t)^{f}\]

<p>So how much money do we lose from unnecessary sales? Let’s calculate the fraction of money we keep, compared to a single sale:</p>

\[\frac{(y^{\frac{n}{f}} (1-t)+t)^{f}}{y^n(1-t)+t} \quad \square\]

<p>For example, if the annual yield is 10% (y=1.1), capital-tax is 25% (t=0.25), number of sales is f=10, and the experiment is for n=10 years, we get a portion of $0.94$. That means we lose 6% due to the 10 sales we did, instead of holding the share and selling it once at the end of the experiment.</p>

<p>If we look at 20 years, and we sell and buy once a year, $y=1.1, t=0.25, f=20, n=20$ we get a portion of $0.80$, meaning that we lose $20\%$, by doing too many transactions. That’s definitely not negligible.</p>

<p>Another example: $y=1.2, t=0.25, f=10, n=10$ we get a portion of $0.83$, meaning that we lose $17\%$.</p>
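<p>A minimal sketch of this retention formula (the function name is mine), reproducing the three examples above:</p>

```python
def retention(y, t, f, n):
    """Fraction of the single-sale payout we keep when selling f times over n years.
    y: yearly yield multiple, t: capital-tax rate."""
    churned = (y ** (n / f) * (1 - t) + t) ** f   # f buy/sell rounds
    hold = y ** n * (1 - t) + t                   # one sale at the end
    return churned / hold

print(f"{retention(1.1, 0.25, 10, 10):.2f}")  # 0.94
print(f"{retention(1.1, 0.25, 20, 20):.2f}")  # 0.80
print(f"{retention(1.2, 0.25, 10, 10):.2f}")  # 0.83
```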

<h3 id="cashing-out">Cashing out</h3>

<p>We have a portfolio of two or more shares, and we need to cash out. Which share should we sell?</p>

<h4 id="two-different-capital-tax-shares">Two Different Capital-Tax Shares</h4>

<p>What happens if we hold two shares, each with different capital-tax rules, and we need to liquidate and get some cash out? Is it better to sell the high-tax share or the low-tax share? Intuitively, we’d want to sell the low-tax share, to pay less tax, right?</p>

<p>We buy share A at \$100 with capital-tax of 25% and buy share B at \$100 with capital-tax of 0%, both double every year. 
After 1 year we need \$100 in cash to buy a TV. Our options are either sell from share A or sell from share B:</p>

<ol>
  <li>
    <p>After 1 year, share A is valued at \$200. If we sell a fraction \$100/\$175 ≈ 57% of our share A position, we cash out 0.5714 * (\$200 - 25% * (\$200 - \$100)) = \$100 to buy our TV, and we’re left with ≈ 0.4286 shares of A and 1.0 of share B, which we didn’t sell. Wait another year, and we now sell the 0.4286 shares of A for 0.4286 * (\$400 - 25% * (\$400 - \$100)) ≈ \$139.29. Then sell share B, pay \$0 in taxes, and get \$400. Total cash after two years: ≈ \$539.29.</p>
  </li>
  <li>
    <p>After 1 year, share B is valued at \$200. Sell \$100 out of share B, pay no taxes, and get the \$100 cash to buy our TV. We’re left with 0.5 shares of B and 1.0 of share A. Wait another year. Sell 1.0 shares of A to get cash of 1.0 * (\$400 - 25% * (\$400 - \$100)) = \$325. Sell the 0.5 shares of B, pay zero taxes, and get \$200. Total cash after two years: \$525.</p>
  </li>
</ol>

<p>In this very specific example, we see that we need to sell share A with the higher capital-tax, but this is <strong>not always the case</strong>, and it depends on all the other parameters, as shown in the following paragraphs.</p>
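<p>These two options can be checked with a short sketch (helper names are mine; both shares start at \$100 and double yearly, and we cash out \$100 after year one):</p>

```python
def net_sale(price, basis, tax):
    """After-tax proceeds of a full position worth `price` with cost basis `basis`."""
    return price - tax * (price - basis)

def total_cash(sell_first, tax_a=0.25, tax_b=0.0, need=100.0):
    """Cash after two years when the TV money comes out of `sell_first`."""
    if sell_first == 'A':
        frac = need / net_sale(200, 100, tax_a)   # ~0.571 of A sold in year 1
        return (1 - frac) * net_sale(400, 100, tax_a) + net_sale(400, 100, tax_b)
    frac = need / net_sale(200, 100, tax_b)       # 0.5 of B sold in year 1
    return net_sale(400, 100, tax_a) + (1 - frac) * net_sale(400, 100, tax_b)

print(round(total_cash('A'), 2), total_cash('B'))  # 539.29 525.0
```

<p>(A fractional position scales linearly — its value and its cost basis shrink together — so multiplying the full-position net proceeds by the remaining fraction is exact.)</p>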

<p>It is <strong>not true</strong> to state that we should keep the share with the lower/higher capital tax, in all cases.</p>

<h1 id="two-shares-with-identical-yield-but-different-purchase-price-current-price-taxation-level">Two shares with identical yield but different purchase price, current price, taxation level</h1>

<p>We bought share A at $b_a$, valued today at $v_a$, with capital-tax $t_a$ (e.g. 0.25), and bought share B at $b_b$, valued today at $v_b$, with capital-tax $t_b$; both have a yield-multiple per year of $y$ (e.g. 1.1). We now need $c$ in cash to buy something, and we plan to hold the shares for $n$ years, until the end of our experiment. Our options are to sell from share A or from share B:</p>

<p>If we decide to sell a fraction of $\frac{c}{v_a - t_a(v_a - b_a)}$ of our <strong>share A</strong> position, we cash out exactly $\frac{c}{v_a - t_a(v_a - b_a)}* (v_a - t_a(v_a - b_a)) = c $ to buy our TV, and we’re left with $1-\frac{c}{v_a - t_a(v_a - b_a)}$ shares of A and 1.0 shares of B. 
Wait $n$ years, and we now sell $1-\frac{c}{v_a - t_a(v_a - b_a)}$ shares of A to have cash of</p>

\[(1 - \frac{c}{v_a - t_a(v_a - b_a)}) * ( v_a y^n(1-t_a) + b_a  t_a)\]

<p>Then sell 1.0 share B to get cash of $ (v_b y^n (1-t_b) + b_b t_b) $.</p>

<script type="text/x-mathjax-config">
  MathJax.Hub.Config({
    TeX: {
      equationNumbers: { autoNumber: "AMS" },
      tagSide: "right"
    },
    tex2jax: {
      inlineMath: [ ['$','$'], ["\\(","\\)"] ],
      displayMath: [ ['$$','$$'], ["\\[","\\]"] ],
      processEscapes: true
    }
  });
  MathJax.Hub.Register.StartupHook("TeX AMSmath Ready", function () {
    MathJax.InputJax.TeX.Stack.Item.AMSarray.Augment({
      clearTag() {
        if (!this.global.notags) {
          this.super(arguments).clearTag.call(this);
        }
      }
    });
  });
</script>

<script type="text/javascript" charset="utf-8" src="https://cdn.jsdelivr.net/npm/mathjax@2/MathJax.js?config=TeX-AMS_CHTML">
</script>

<p>Total cash after $n$ years:</p>

\[\begin{equation}
  \label{eq:1}
  \begin{aligned}
    (1 - \frac{c}{v_a(1-t_a) +   b_a  t_a}) * ( v_a y^n(1-t_a) + b_a  t_a) +      \underbrace{v_b y^n (1-t_b) + b_b t_b}_{\text{X}}         
  \end{aligned}
\end{equation}\]

<p>If we decide to cash out from <strong>share B</strong>, we just need to swap variables, and the cash we get after $n$ years is:</p>

\[\begin{equation}
  \label{eq:2}
  \begin{aligned}
    (1 - \frac{c}{v_b (1-t_b) + b_b t_b }) * ( v_b y^n(1-t_b) + b_b  t_b) +   \underbrace{v_a y^n (1-t_a) + b_a t_a}_{\text{Y}} 
  \end{aligned}
\end{equation}\]

<p>Let’s compare which expression is higher, and subtract X and Y from both sides, to find the maximal expression:</p>

\[\begin{equation}
  \label{eq:3}
  \begin{aligned}
    ( - \frac{c}{v_a(1-t_a) +   b_a  t_a}) * ( v_a y^n(1-t_a) + b_a  t_a) \: ?  \: ( - \frac{c}{v_b (1-t_b) + b_b t_b }) * ( v_b y^n(1-t_b) + b_b  t_b)
  \end{aligned}
\end{equation}\]

<p>As you can see, $c$ cancels out; that’s why the amount of cash we need does not affect the decision. Let’s divide by $(-c)$, which flips the inequality, so we now look for the <strong>minimal</strong> expression:</p>

\[\begin{equation}
  \label{eq:4}
  \begin{aligned}
     \frac{ v_a y^n(1-t_a) + b_a  t_a}{v_a(1-t_a) +   b_a  t_a}  \: \: ?  \: \:  \frac{ v_b y^n(1-t_b) + b_b  t_b}{v_b (1-t_b) + b_b t_b }
  \end{aligned}
\end{equation}\]

<p>It’s nice to see that each side does not mix variables between shares. Let’s look at one side, divide the numerator and the denominator by $b$, and define $m=\frac{v}{b}$.</p>

<p>So, in the <u>general case</u>, the optimal strategy is to sell the share that has a <strong>minimum</strong> value of the gain:</p>

\[\begin{equation}
  \label{eq:7}
  \begin{aligned}
      G := \frac{ m y^n- t(m y^n- 1)}{m- t(m-   1)}
  \end{aligned}
\end{equation}\]

<p>As we can see, $v$ and $b$ no longer appear in the formula, only $m$ does. That means the only parameters that can affect the decision are the <strong>quadruplet (y, t, m, n)</strong>.</p>

<p>Alternatively, we can denote what we expect the price of the share to be in n years as $F=v y^n$ and we get:</p>

\[\begin{equation}
  \label{eq:6}
  \begin{aligned}
     G := \frac{    \underbrace{F- t(F- b)}_{\text{Future Net Proceeds}}       }{   \underbrace{v- t(v - b) }_{\text{Current Net Proceeds}} }
  \end{aligned}
\end{equation}\]

<p>If you observe carefully, you see that the numerator is the <strong>after-tax proceeds</strong> if we sell the share in n years, and the denominator is the <strong>after-tax proceeds</strong> if we sell the share today. In other words, the golden rule is: <strong>Out of a portfolio of shares, sell the share for which the ratio of future after-tax proceeds to current after-tax proceeds is the lowest</strong>. If we define this ratio as the gain G, we should strive to <strong>hold the shares having the highest G</strong>.</p>
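<p>As a quick sketch (function and variable names are mine), the golden rule can be coded directly:</p>

```python
def gain(b, v, t, y, n):
    """G = future after-tax proceeds / current after-tax proceeds."""
    future = v * y ** n
    return (future - t * (future - b)) / (v - t * (v - b))

# Revisiting the TV example: both shares double yearly,
# A is taxed at 25%, B at 0%.
g_a = gain(b=100, v=200, t=0.25, y=2.0, n=1)   # 325/175
g_b = gain(b=100, v=200, t=0.00, y=2.0, n=1)   # = y**n = 2.0
assert g_a < g_b   # lowest G goes first: sell A, keep B
```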

<h3 id="if-two-shares-have-the-same-yield-and-taxation">If two shares have the same yield and taxation</h3>

<p>What if two shares have the same yield and taxation? By formula (5), two such shares with the same m have the same gain. But should we sell the share with the higher m or the lower one? Let’s differentiate G with respect to m to see how it changes:</p>

\[\begin{aligned}
     \frac{\partial G}{\partial m} =   \frac{(y^n- t y^n)(m- t(m- 1))-(m y^n- t(m y^n- 1))(1- t)}{(m- t(m-   1))^2} =
  \end{aligned}\]

\[\begin{aligned}     
     \frac{ (1-t) \quad  [  y^nm-ty^nm + y^nt -  m y^n   +  tm y^n - t]     }{(m- t(m-   1))^2}=
  \end{aligned}\]

\[\begin{aligned}     
     \frac{ (1-t) \quad t [   y^n - 1]     }{(m- t(m-   1))^2}
  \end{aligned}\]

<p>Now, since $t$ is positive, $(1-t)$ is positive, the denominator is positive, and if the yield is positive then $y&gt;1$, so $y^n - 1$ is positive as well. Hence the derivative is positive: increasing m increases G, i.e. G is monotonically increasing w.r.t. m. So if we have two (or more) shares with the same yield and taxation, the only thing we need to know is m, and we want to sell the share with the lowest m. That is, we <strong><u>want to sell the share with the lowest current-price-to-purchase-price ratio</u></strong>: the share whose price has grown by the smallest multiple since we purchased it.</p>
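<p>A short numeric sanity check of this monotonicity (a sketch; the parameter values are arbitrary):</p>

```python
def gain_m(m, y, t, n):
    """G as a function of m = current price / purchase price."""
    return (m * y ** n - t * (m * y ** n - 1)) / (m - t * (m - 1))

ms = [0.5, 1.0, 2.0, 4.0, 8.0]
gs = [gain_m(m, y=1.1, t=0.25, n=10) for m in ms]
assert gs == sorted(gs)   # G grows with m, so sell the lowest-m share
```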

<p>Observation 1: If you apply this rule to the case when we <strong>bought two lots of the same share/company, each with a different price-per-share</strong>, then since obviously the current price per share is identical (because it’s the same share/symbol), selling the lower <em>m</em> means selling the share/lot with the <strong>higher</strong> purchase price. This is sometimes called  <a href="https://www.investopedia.com/terms/s/specificsharesmethod.asp" target="_blank">Highest Cost Basis Policy</a>, and it’s obvious: if you have 2 shares, and you need to sell 1, if you sell the share with the higher cost, you will pay less tax, and still remain with 1 share.</p>

<p>Observation 2: If you have multiple shares, all with the same yield and taxation, the sell-the-<strong>lowest</strong>-m strategy is equivalent to selling the share whose current tax event is the <strong>smallest</strong> in absolute dollars (the proof is similar, by differentiating the current tax payment w.r.t. m).</p>

<h3 id="one-share-has-zero-taxation-but-both-have-the-same-yield">One share has zero taxation, but both have the same yield</h3>

<p><strong>If share A has zero taxation, and B has nonzero taxation</strong>, and share A has the same yield (or better) than share B, <strong>it’s always better to keep share A</strong>, and sell B. According to equation (5) we can see share A has G of $y^n$. Let’s prove that the gain of A is higher than the gain of B:</p>

\[\begin{aligned}     
     y^n &gt; \frac{ m y^n- t(m y^n- 1)}{m- t(m -   1)} \rightarrow  (m-t(m-1))y^n &gt; my^n - t(my^n -1)
  \end{aligned}\]

\[\begin{aligned}     
     ty^n &gt; t \rightarrow y^n&gt;1 \quad \square
  \end{aligned}\]

<p>So if the expected yield is positive, the inequality holds, which proves our statement.</p>

<h3 id="concrete-examples">Concrete examples</h3>

<p>And now for some concrete examples:</p>

<p>Example of lower <strong>nonzero</strong> taxation of B, identical yield of A and B, but still with different optimal strategy:</p>

<p>$b_a=400, v_a=1800, t_a=0.25, b_b=400, v_b=1500, t_b=0.15, y=1.05, n=8, c=600$ -&gt; best strategy: A</p>

<p>$b_a=200, v_a=3900, t_a=0.25, b_b=200, v_b=1400, t_b=0.15, y=1.30, n=10, c=300$ -&gt; best strategy: B</p>

<p>What happens if we expect a <strong>different yield for each share</strong>? Should we always keep the share with the better yield (assuming identical risk, i.e. variance)?
Surprisingly and counterintuitively, the answer is no! Again, it depends on all the other factors.</p>

<p>Example of better yield in share B, but we still need to cash out B:</p>

<p>$b_a=100, v_a=2000, t_a=0.25, y_a=1.35, b_b=800, v_b=1000, t_b=0.20, y_b=1.40, n=2, c=300$ -&gt; best strategy: sell B</p>

<p>And another surprising result: What happens if both shares have the <strong>same taxation</strong> (as in many real life cases), but the <strong>yield on share B is higher</strong>? Should we keep share B? No! Have a look:</p>

<p>$b_a=100, v_a=1700, t_a=0.30, y_a=1.35, b_b=1000, v_b=1600, t_b=0.30, y_b=1.40, n=2, c=500$ -&gt; best strategy: sell B</p>

<p>Another surprising result: share B has capital tax of 0%, while share A has capital tax of 20%, and yet we need to sell B. This is because share A has a better yield than share B, which incentivizes keeping it. This illustrates the counteracting effects: in general we’d like to keep shares with lower taxation and better yield, but here the yield factor outweighs the tax factor.</p>

<p>$b_a=200, v_a=2500, t_a=0.20, y_a=1.35, b_b=700, v_b=3500, t_b=0.00, y_b=1.20, n=8, c=700 $ -&gt; best strategy: sell B, \$29569 vs \$34144</p>

<p>There are even examples where <strong>share B has a better yield <em>and</em> better taxation, and we still need to sell it</strong>:</p>

<p>$b_a=100, v_a=2600, t_a=0.25, y_a=1.35, b_b=700, v_b=900, t_b=0.20, y_b=1.40, n=2, c=400$ -&gt; best strategy: sell B, \$4405 vs \$4408</p>

<p>However, if share B has a capital tax of exactly zero, share A has nonzero tax, and the yield of share A is equal to or less than B’s, we should always keep B (with the zero capital tax) and sell A, as proved above.</p>

<p>Sometimes, only the <strong>length of the experiment</strong> affects the decision, when all the other parameters are identical:</p>

<p>$b_a=800, v_a=2600, t_a=0.10, y_a=1.30, b_b=800, v_b=1000, t_b=0.30, y_b=1.35, n=9, c=200 $ -&gt; best strategy: sell A, \$33502 vs \$33290</p>

<p>$b_a=800, v_a=2600, t_a=0.10, y_a=1.30, b_b=800, v_b=1000, t_b=0.30, y_b=1.35, n=5, c=200 $ -&gt; best strategy: sell B, \$11422 vs \$11428</p>
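<p>The examples above can be reproduced with a short simulation (a sketch; fields follow the post’s notation: purchase price $b$, current value $v$, tax $t$, yield $y$, horizon $n$, cash needed $c$):</p>

```python
def final_wealth(sell, shares, n, c):
    """Wealth after n years if we raise net cash c today by selling part of
    share `sell`; every other share is held fully and sold after n years."""
    total = 0.0
    for i, s in enumerate(shares):
        b, v, t, y = s["b"], s["v"], s["t"], s["y"]
        frac = c / (v - t * (v - b)) if i == sell else 0.0  # portion sold to net c
        # the remaining portion is held n years, then sold and taxed on the gain
        total += (1 - frac) * (v * y**n - t * (v * y**n - b))
    return total

# The 0%-tax-on-B example above: selling B is still better, $34144 vs $29569.
shares = [dict(b=200, v=2500, t=0.20, y=1.35),
          dict(b=700, v=3500, t=0.00, y=1.20)]
sell_A = final_wealth(0, shares, n=8, c=700)
sell_B = final_wealth(1, shares, n=8, c=700)
assert round(sell_B) == 34144 and round(sell_A) == 29569
```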

<h1 id="what-happens-when-were-forced-to-materialize-a-profit-should-we-sell-a-lossy-share">What happens when we’re forced to materialize a profit? Should we sell a lossy share?</h1>

<p>We know from previous conclusions that when we have two shares with identical future expectations, one in profit and one in loss, and we need to cash out, we should sell the one in loss. However, what happens when we are <strong>forced</strong> to sell the share in profit, because of some external constraint, or because we have information that the company is going to collapse? Or, what happens when we have a profit from some other capital gain (e.g. rent) that can be offset by selling a share in loss? Should we sell a lossy share to offset positive capital tax?</p>

<p>Let’s define the problem statement: The capital tax is $t$. We materialized a profit of $w$ and bought a TV with it, so we need to pay a tax of $\b{tw}$. We hold 1 unit of a lossy share that we bought at price $b$; its current value is $v$ ($\b{v &lt; b}$). We consider selling a portion $\b{p}$ of the lossy share. If we sell a small portion, we won’t have the cash to pay the tax $tw$, so we’ll have to take a loan with interest multiple $\b{r&gt;1}$ (e.g. 1.05; this can be seen as the risk-free interest, or discount rate). If we sell too much, we’ll have spare cash which we will use to buy back the lossy share. Either way, we will sell the share in full $n$ years from now (assume it will then be in profit, $vy^n&gt;b$), so that the scenarios are comparable.</p>

<p>The tax that we need to pay now is $T\equiv t(w -  \underbrace{p(b-v)}_{\text{positive}} )$. If the tax is positive, we need to pay it; if it’s negative, we don’t get it as cash from the authorities, but it can be deducted from the tax we’ll pay when we sell the share, in full, $n$ years from now.</p>

<p>If $T$ is positive, the cash we have now is $C_{tp} \equiv pv-T$. If $C_{tp} &lt;0$ we need a loan for the tax payment, otherwise if $C_{tp} &gt;0$ we can purchase more units of our share. <br />
If $T$ is negative, the cash we have now (positive for sure) is $C_{tn} \equiv pv$ and we can purchase the share with it, and save the tax benefit for the future.</p>

<p>Consider the loan-case where $T&gt;0,C_{tp}&lt;0$. The cash we have after $n$ years is:</p>

\[\begin{aligned}     
      C_{tpcn}\equiv \underbrace{C_{tp}r^n }_{\text{pay the debt}} +   \underbrace{(1-p)(vy^n(1-t)+tb)}_{\text{sell the remaining portion}}       
  \end{aligned}\]

<p>Let’s differentiate it by $p$, to find the optimal strategy:</p>

\[\begin{aligned}     
      C_{tpcn}' =  \underbrace{r^n (v(1-t)+tb) }_{\text{cash if we'd sell share today and lend it}}   -   \underbrace{(vy^n(1-t)+tb) }_{\text{cash if wait n years and then sell}}   
  \end{aligned}\]

<p>It’s interesting to see that the derivative is negative whenever it’s better to hold a share than to sell it now (and lend the money), which is usually the case; otherwise there would be no point in holding any share. So, since $C_{tpcn}' &lt; 0$, we would like to reduce $p$ to 0 in order to maximize our utility. This means we should <strong>not</strong> sell any portion of the lossy share, but just take a loan. At $p=0$ we will have: $C_{tpcn}(p=0)=-twr^n+vy^n(1-t)+tb$ $\square$</p>

<p>Consider the positive cash case where $T&gt;0,C_{tp}&gt;0$. We’ll use all the available cash to buy the share.</p>

\[\begin{aligned}     
      C_{tpcp} = \underbrace{C_{tp}(y^n(1-t)+t)}_{\text{from buying more portion today}}  +   \underbrace{(1-p)(vy^n(1-t)+tb)}_{\text{sell the remaining portion}}  
  \end{aligned}\]

<p>Let’s differentiate w.r.t. $p$:</p>

\[\begin{aligned}     
      C_{tpcp}' =  (v(1-t) + tb)(y^n(1-t)+t)-  (vy^n(1-t)+tb)   
  \end{aligned}\]

<p>This will be positive when our assumption ($v &lt; b$) holds.</p>

<details>
  <summary>Click for proof</summary>
  
  $$
  \begin{aligned}     
      C_{tpcp}' =  vy^n(1-t)^2 +  vt(1-t)     + tby^n(1-t) + t^2b  -  vy^n(1-t) - tb  
  \end{aligned}
$$

  $$
  \begin{aligned}     
      =  vy^n(1-t)^2 +  vt(1-t)     + tby^n(1-t) -  vy^n(1-t)  - tb(1-t) 
  \end{aligned}
$$

 $$
  \begin{aligned}     
= (1-t)(  vy^n(1-t) + vt  + tby^n -  vy^n - tb)=t(1-t)(  -vy^n + v  + by^n  - b)
  \end{aligned}
$$
 $$
  \begin{aligned}     
=t(1-t)(y^n-1)(b-v)
  \end{aligned}
$$

Since all elements are positive, the whole expression is positive. $\square$

</details>


<p>That means increasing $p$ will increase $C_{tpcp}$ and decrease $T$, until we reach $T=0$; at that point $p$ reaches $w-p(b-v)=0 \rightarrow p = \frac{w}{b-v}$, and at this $p$ we will have:</p>

\[\begin{aligned}     
      C_{tpcp} = pv(y^n(1-t)+t)  +  (1-p)(vy^n(1-t)+tb)  = t(b-w)+vy^n(1-t)
  \end{aligned}\]

<p>Now consider the case of $ T &lt;0 $ (when $p$ is above $\frac{w}{b-v}$):</p>

\[\begin{aligned}     
 C_{tncp}=  \underbrace{pv(y^n(1-t)+t)}_{\text{from buying more portion today}}  +    \underbrace{(1-p)(vy^n(1-t)+tb)}_{\text{sell the remaining portion}}  -  \underbrace{t(w-p(b-v))}_{\text{tax we get back}}  
  \end{aligned}\]

\[\begin{aligned}     
 C_{tncp}'=  v(y^n(1-t)+t) -(vy^n(1-t)+tb) + t(b-v) = vt -tb + t(b-v)=0
  \end{aligned}\]

<p>Since the derivative is 0, there’s no point in increasing $p$ beyond $\frac{w}{b-v}$: nothing changes.</p>

<p>So we have two stationary points: one at $p=\frac{w}{b-v}$ with $C_{tpcp}$ and no loan, and the other at $p=0$ with $C_{tpcn}$, where we take a loan to pay the tax. Let’s compare the two optima:</p>

\[\begin{aligned}     
\underbrace{t(b-w)+vy^n(1-t)}_{\text{when we sell a portion to have zero total tax}} \quad  ?  \quad \underbrace{-twr^n+vy^n(1-t)+tb}_{\text{when p=0 and we take a loan}} 
  \end{aligned}\]

\[\begin{aligned}     
\underbrace{-1}_{\text{when we sell a portion to have zero total tax}} \quad  ?  \quad \underbrace{-r^n}_{\text{when p=0 and we take a loan}} 
  \end{aligned}\]

<p>Since $r&gt;1$, we see that the best strategy is <strong>not to take a loan</strong>, but instead to sell a portion $p=\frac{w}{b-v}$ (or more) of the lossy share, <strong>so that the total tax payment is zero</strong>. We then have free cash from cashing out that portion of the lossy share, and we use it to buy the share again (or a different share, assuming all have the same expected future yield). This strategy is sometimes called Tax-Loss Harvesting; selling and buying back the same share is sometimes called a Wash Sale.</p>
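<p>The comparison can be sketched numerically (function names and the sample parameters are mine; $w$ is the realized profit, $b,v$ the lossy share’s purchase price and current value, $r$ the loan interest multiple):</p>

```python
def harvest(w, b, v, t, y, n):
    """Sell portion p = w/(b-v) of the lossy share so the total tax today is
    zero, reinvest the proceeds, and sell everything after n years."""
    p = w / (b - v)
    cash = p * v                               # tax T = t*(w - p*(b-v)) = 0
    reinvested = cash * (y**n * (1 - t) + t)   # buy today, sell in n years
    remaining = (1 - p) * (v * y**n * (1 - t) + t * b)
    return reinvested + remaining

def take_loan(w, b, v, t, y, n, r):
    """Keep the whole lossy share (p = 0) and borrow to pay the tax t*w."""
    return -t * w * r**n + v * y**n * (1 - t) + t * b

h = harvest(w=30, b=100, v=50, t=0.25, y=1.2, n=5)
l = take_loan(w=30, b=100, v=50, t=0.25, y=1.2, n=5, r=1.05)
assert h > l   # harvesting beats the loan, as derived above
```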

<p>Should we sell the profitable share first, or the lossy one first? In some countries, the bank or trading institution collects the tax at the moment you sell a share and transfers it to the tax authorities; only at the end of the year, when you file the annual tax report, do you get the excess tax back. In these countries it is better to first sell the lossy share and only then the profitable one. This way, when you sell the profitable share, the bank takes the capital loss from the lossy share into account, so you pay less tax and don’t have to wait until the end of the year.</p>

<h1 id="if-we-sold-a-lossy-share-should-we-also-sell--buy-a-profitable-share-to-increase-its-purchase-price">If we sold a lossy share, should we also sell &amp; buy a profitable share to increase its purchase price?</h1>

<p>Let’s follow our usual example, with a capital tax of 25%. We hold 1 lossy share A (b=\$100, v=\$50), and 1 profitable share B (b=\$100, v=\$200). We sell share A and buy a TV for \$50. Two cases:</p>

<ol>
  <li>We keep the profitable share for one year, and then sell it at \$400. Our profit is \$300 minus the capital loss from previous year of \$50, so we pay tax of \$62.5 and have cash of \$337.5</li>
  <li>We sell the profitable share at \$200; the capital profit is \$100 minus the \$50 loss, i.e. \$50, so we pay tax of \$12.5 and have cash of \$187.5. We buy the share again at \$187.5. After one year we sell it at \$375, pay tax of \$46.875, and have cash of \$328.125</li>
</ol>

<p>We can see it’s better <strong>not to touch the profitable share</strong>, which corroborates our previous conclusion that excessive transactions in profitable shares are disadvantageous.</p>
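<p>The arithmetic of the two cases, as a quick check (the 25% tax rate is implied by the numbers above):</p>

```python
T = 0.25           # capital-tax rate implied by the example
loss_offset = 50   # realized loss from selling share A

# Case 1: hold share B one year, then sell at $400
case1 = 400 - T * ((400 - 100) - loss_offset)

# Case 2: sell B now at $200, offset the loss, rebuy, price doubles in a year
cash_now = 200 - T * ((200 - 100) - loss_offset)
case2 = 2 * cash_now - T * (2 * cash_now - cash_now)

assert case1 == 337.5 and case2 == 328.125 and case1 > case2
```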

<h1 id="conclusions"><strong>Conclusions</strong></h1>

<p>Unless we have other considerations, for example diversifying the portfolio to reduce variance (risk) or reduce correlation between shares:</p>

<ol>
  <li>
    <p>Better to <strong>minimize unneeded transactions</strong> of selling and buying a <strong>share in profit</strong>, or switching a share in profit to a different share, due to the tax event we have to pay. Having transaction fees even strengthens this statement.</p>
  </li>
  <li>
    <p>If we <strong>need to cash out</strong>, usually <strong><u>all parameters matter</u></strong>, and in general we need to cash out the share with the lowest gain $G$, as defined in (5). However:</p>

    <p>2.1 If a share has <strong>capital tax of exactly zero</strong>, with <strong>better or equal yield</strong> to the other share, we should always keep it, and the current shares prices or purchase prices do not matter. If the capital tax is zero but the yield is worse, we should consider all other parameters.</p>

    <p>2.2. If two shares have the <strong>same yield and taxation</strong>, we should sell the share with the <strong>lowest m</strong> (lowest price increase in percentage), meaning the lowest current price to purchase price ratio. <strong>This is equivalent to choosing the share to sell, in which the present tax event size in $ is the lowest</strong>. That also means that we should <strong>prefer selling shares in loss rather than shares in profit.</strong></p>

    <p>2.3 All parameters matter, <strong>even if one share has a better yield than the other</strong>, and <strong>even if one share has a better yield and better taxation than the other</strong>. In some cases we need to sell the share with the higher yield. In some cases we need to sell the share with the lower taxation ratio.</p>

    <p>2.4  Higher capital-tax on a share A increases the chances we need to sell A, but other parameters can change the decision.</p>

    <p>2.5  Higher yield on a share A increases the chances we need to sell B, but other parameters can change the decision. Combining the two last statements: <strong>In general it’s more likely to keep shares with lower taxation and better yield, but other parameters can change the decision</strong>.</p>
  </li>
  <li>
    <p><strong>If we had a capital profit</strong>, we should offset it by selling a share in loss, in an amount that zeroes out our capital profit for this year, and buy the share again (or an equivalent share). In countries where the tax is collected by the trading institution, it is better to first sell the lossy share and then the profitable one.</p>
  </li>
</ol>

<p>In general, you should <strong><u>try to avoid actions which result in a tax payment today</u></strong>.</p>

\[\square\]

<p>Here’s a   <a href="https://docs.google.com/spreadsheets/d/1O9_H30AT-8z8WN2Z-YT-9GS1_SAzSmROk3yo4pOR3Mk/edit?usp=sharing" target="_blank">sheet</a> table in which you can enter the parameters and make the right decision of which share to cash out.</p>]]></content><author><name>Nadav Benedek</name></author><category term="tax," /><category term="shares" /><summary type="html"><![CDATA[When we sell a share, we need to pay capital tax.]]></summary></entry><entry><title type="html">Distance between two lines in 3D</title><link href="https://nadavb.com/distance-two-lines-3d/" rel="alternate" type="text/html" title="Distance between two lines in 3D" /><published>2023-07-10T14:22:00+03:00</published><updated>2023-07-10T14:22:00+03:00</updated><id>https://nadavb.com/distance-two-lines-3d</id><content type="html" xml:base="https://nadavb.com/distance-two-lines-3d/"><![CDATA[<p>Assume you have two parametric lines: $p_1=r_1+e_1$ and $p_2=r_2+e_2$</p>

<p>Start by checking whether they are parallel: are the normalized direction vectors ($e$) identical (up to sign)? If they are, pick any point on $p_1$ and apply the point-to-line distance formula (google it).</p>

<p>Otherwise, continue as follows:
The ‘distance’ between the lines is the minimum distance over all pairs of points A, B on the two lines. So let A, B be the points that achieve this minimum.<br />
Now, the line AB must be perpendicular to both lines $p_1,p_2$. Why? Because otherwise you could move a bit along one of the lines and make the distance shorter: if you move $\epsilon$, the distance is reduced by $\epsilon \cdot \cos(\alpha)$, where $\alpha \neq 90^{\circ}$</p>

<p>So the only vector direction which is perpendicular to both lines is $n=e_1 \times e_2$. This is by definition of cross product. Let’s define $\hat{n}$ as its normalized version. Great.</p>
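<p>Skipping ahead, the whole recipe, including the parallel case, fits in a few lines of Python; a sketch using plain tuples (the final formula $d=\hat{n} \cdot (r_1 - r_2)$ is derived below, taken here in absolute value):</p>

```python
def cross(u, w):
    return (u[1]*w[2] - u[2]*w[1], u[2]*w[0] - u[0]*w[2], u[0]*w[1] - u[1]*w[0])

def norm(u):
    return sum(c * c for c in u) ** 0.5

def line_distance(r1, e1, r2, e2, eps=1e-12):
    """Distance between the lines r1 + s*e1 and r2 + s*e2 in 3D."""
    n = cross(e1, e2)
    d = tuple(a - b for a, b in zip(r1, r2))
    if norm(n) < eps:                       # parallel: point-to-line distance
        return norm(cross(e1, d)) / norm(e1)
    return abs(sum(ni * di for ni, di in zip(n, d))) / norm(n)

# skew: the x-axis vs. the y-axis lifted to z=1, distance 1
assert line_distance((0, 0, 0), (1, 0, 0), (0, 0, 1), (0, 1, 0)) == 1.0
# parallel: two parallel lines 1 apart
assert line_distance((0, 0, 0), (1, 0, 0), (0, 1, 0), (2, 0, 0)) == 1.0
```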

<p>Now imagine that we place the coordinate origin at point B. Since $AB$ is perpendicular to line $p_1$, and $r_1$ lies on $p_1$, the triangle $r_1AB$ has a right angle at A, and the distance $|AB|=d$ is <br />
$d=\hat{n} \cdot (r_1-B) $ <br />
Great, now let’s write $B$ as $B=r_2 + (B-r_2)$ and we get:<br />
$d=\hat{n} \cdot (r_1-r_2-(B-r_2))$ <br />
Now, the vector $B-r_2$ is perpendicular to $\hat{n}$, by definition, so its dot product with $\hat{n}$ is zero. Therefore: <br />
\(d=\hat{n} \cdot (r_1 - r_2) \: \: \square\)</p>]]></content><author><name>Nadav Benedek</name></author><category term="distance-3d-lines" /><summary type="html"><![CDATA[Assume you have two parametric lines: $p_1=r_1+e_1$ and $p_2=r_2+e_2$]]></summary></entry><entry><title type="html">Intuitive explanation for the max-min inequality: Why min-max is always greater than max-min.</title><link href="https://nadavb.com/max_min_inequality_intuitive_explanation/" rel="alternate" type="text/html" title="Intuitive explanation for the max-min inequality: Why min-max is always greater than max-min." /><published>2023-04-05T14:22:00+03:00</published><updated>2023-04-05T14:22:00+03:00</updated><id>https://nadavb.com/max_min_inequality_intuitive_explanation</id><content type="html" xml:base="https://nadavb.com/max_min_inequality_intuitive_explanation/"><![CDATA[<p>Min-Max is always greater than Max-Min:</p>

\[\min_y \max_x f(x,y) \geq \max_x \min_y f(x,y)\]

<p>Why?</p>

<p>Look at the following table, showing the values of a simple function $f(x,y)$ for $x,y=1,2,3$. At the top you see the minimum of every column, which is min-y, and on the right side the maximum of every row, which is max-x.</p>

<table>
  <thead>
    <tr>
      <th>y \ min-y</th>
      <th>2</th>
      <th>1</th>
      <th>1</th>
      <th>max-x</th>
    </tr>
  </thead>
  <tbody>
    <tr>
      <td>3</td>
      <td>5</td>
      <td>1</td>
      <td>1</td>
      <td>5</td>
    </tr>
    <tr>
      <td>2</td>
      <td>8</td>
      <td>1</td>
      <td>3</td>
      <td>8</td>
    </tr>
    <tr>
      <td>1</td>
      <td>2</td>
      <td>6</td>
      <td>4</td>
      <td>6</td>
    </tr>
    <tr>
      <td>x</td>
      <td>1</td>
      <td>2</td>
      <td>3</td>
      <td> </td>
    </tr>
  </tbody>
</table>

<p>Let’s prove that <em>every</em> number in the max_x column is greater than <em>any</em> number in the min_y row. For simplicity, every time we write ‘greater’ we mean ‘greater or equal’.</p>

<p>Can it be that a number in max_x is less than some number in min_y?</p>

<p>Let’s say we want to reduce the number 6 in max_x to below the number 2 in min_y. That means we’d need to replace the whole row with ones, so the number in max_x becomes 1; but then the number 2 in min_y also becomes 1, since min_y takes the column minimum.</p>

<p>In general, for every row and column we look at, there is some number at their intersection; call it $a$. The corresponding number in max_x is always greater than $a$, by definition, and the corresponding number in min_y is always lower than $a$, by definition. That’s why, for any row and column we choose, the corresponding number in max_x is always higher than the corresponding number in min_y.</p>

<p>And that means every number in the max_x column is greater (or equal) than <em>all</em> the numbers in the min_y row.</p>
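<p>The argument can be checked directly on the table above (a quick sketch):</p>

```python
# f(x,y) from the table: rows are y = 3, 2, 1 (top to bottom), columns x = 1, 2, 3
rows = [[5, 1, 1], [8, 1, 3], [2, 6, 4]]

row_maxes = [max(r) for r in rows]        # max over x, for each y
col_mins = [min(c) for c in zip(*rows)]   # min over y, for each x

min_max = min(row_maxes)   # min_y max_x f = 5
max_min = max(col_mins)    # max_x min_y f = 2
assert min_max >= max_min
# stronger: every row max dominates every column min
assert all(rm >= cm for rm in row_maxes for cm in col_mins)
```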

<p>If it’s true in general for any two numbers in max_x and min_y, it must be true for the specific number \(min  (max_x)\) and the number \(max (min_y)\), therefore:</p>

\[min_y max_x f(x,y) \geq max_x min_y f(x,y) \hspace{1cm} \square\]]]></content><author><name>Nadav Benedek</name></author><category term="max-min" /><summary type="html"><![CDATA[Min-Max is always greater than Max-Min:]]></summary></entry><entry><title type="html">ChatGPT - How does it work?</title><link href="https://nadavb.com/ChatGPT/" rel="alternate" type="text/html" title="ChatGPT - How does it work?" /><published>2023-01-07T05:20:00+02:00</published><updated>2023-01-07T05:20:00+02:00</updated><id>https://nadavb.com/ChatGPT</id><content type="html" xml:base="https://nadavb.com/ChatGPT/"><![CDATA[<p>ChatGPT, how does it work: <a href="https://www.youtube.com/watch?v=g-jRKS8zZaw">Youtube</a>.</p>]]></content><author><name>Nadav Benedek</name></author><category term="ChatGPT" /><summary type="html"><![CDATA[ChatGPT, how does it work: Youtube.]]></summary></entry><entry><title type="html">Various Tips, Tricks, and Anecdotes for Training Neural Networks</title><link href="https://nadavb.com/Tips-and-tricks-for-training-neural-networks/" rel="alternate" type="text/html" title="Various Tips, Tricks, and Anecdotes for Training Neural Networks" /><published>2022-10-05T11:27:00+03:00</published><updated>2022-10-05T11:27:00+03:00</updated><id>https://nadavb.com/Tips%20and%20tricks%20for%20training%20neural%20networks</id><content type="html" xml:base="https://nadavb.com/Tips-and-tricks-for-training-neural-networks/"><![CDATA[<h4 id="finetuning-a-pretrained-model-architecture-vs-training-from-scratch">Finetuning a pretrained model architecture vs. training from scratch</h4>

<p>I recall a case when I helped with supervised model training. The input was a 32x32 image and the output was one of 7 classes. Our dataset size was around 140 images. We used augmentation heavily.</p>

<p>When we took a resnet32 architecture and trained it from scratch, we got 1.41 test loss and 0.97 training loss.</p>

<p>When we used a CIFAR10-resnet32 pretrained architecture and continued with finetuning, we got 0.51 test loss and 0.09 training loss: a huge improvement. It’s worth mentioning that at this stage we kept the last fully connected layer of CIFAR10-resnet32 intact, even though our dataset had only 7 labels and not 10; this did not matter. In addition, the training converged twice as fast as training from scratch. CIFAR10’s dataset size is 60,000 images, far larger than our 140-image dataset. Therefore, when the dataset is small, one should try using a pretrained model.</p>

<h4 id="replacing-the-last-linear-layer-of-a-classification-model">Replacing the last linear layer of a classification model</h4>

<p>Often, when one takes a model architecture, resnet for example, the best practice is to replace the last linear layer with a new layer that has the number of output classes you need. However, when you just replace a layer, you lose its pretrained weights. Does it matter? Is keeping the last layer’s weights as a starting point important?</p>

<p>Let’s take the previous section problem and dataset and see what happened.</p>

<p>When we used a pretrained CIFAR10-resnet32 with a 10-class output on our 7-class dataset, we got:
0.51 test loss (0.09 train loss)</p>

<p>When we replaced the last layer with a linear layer with a 7-class output, we got:
0.58 test loss (0.25 train loss)</p>

<p>So you can see that the performance is lower.</p>

<p>When we replaced the last layer with a linear layer with a 7-class output while <em>preserving</em> the weights of the relevant neurons, we got:
0.49 test loss (0.09 train loss).
So we even improved the performance a bit, and our model has slightly fewer parameters.</p>

<p>To conclude, in this case, keeping the pretrained model weights, even when we need to change the last layer, is important.</p>]]></content><author><name>Nadav Benedek</name></author><category term="train," /><category term="neural" /><category term="network," /><category term="tips," /><category term="tricks" /><summary type="html"><![CDATA[Finetuning a pretrained model architecture vs. training from scratch]]></summary></entry></feed>