|
| 1 | +/- |
| 2 | +Copyright (c) 2017 Johannes Hölzl. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Author: Johannes Hölzl |
| 5 | +
|
| 6 | +Probability mass function -- discrete probability measures |
| 7 | +-/ |
| 8 | +import analysis.nnreal analysis.ennreal analysis.topology.infinite_sum |
| 9 | +noncomputable theory |
| 10 | +variables {α : Type*} {β : Type*} {γ : Type*} |
| 11 | +local attribute [instance] classical.prop_decidable |
| 12 | + |
| 13 | +/-- Probability mass functions, i.e. discrete probability measures -/ |
| 14 | +def {u} pmf (α : Type u) : Type u := { f : α → nnreal // is_sum f 1 } |
| 15 | + |
| 16 | +namespace pmf |
| 17 | + |
| 18 | +instance : has_coe_to_fun (pmf α) := ⟨λp, α → nnreal, λp a, p.1 a⟩ |
| 19 | + |
| 20 | +@[extensionality] protected lemma ext : ∀{p q : pmf α}, (∀a, p a = q a) → p = q |
| 21 | +| ⟨f, hf⟩ ⟨g, hg⟩ eq := subtype.eq $ funext eq |
| 22 | + |
| 23 | +lemma is_sum_coe_one (p : pmf α) : is_sum p 1 := p.2 |
| 24 | + |
| 25 | +lemma has_sum_coe (p : pmf α) : has_sum p := has_sum_spec p.is_sum_coe_one |
| 26 | + |
| 27 | +@[simp] lemma tsum_coe (p : pmf α) : (∑a, p a) = 1 := tsum_eq_is_sum p.is_sum_coe_one |
| 28 | + |
| 29 | +def support (p : pmf α) : set α := {a | p.1 a ≠ 0} |
| 30 | + |
| 31 | +def pure (a : α) : pmf α := ⟨λa', if a' = a then 1 else 0, is_sum_ite _ _⟩ |
| 32 | + |
| 33 | +@[simp] lemma pure_apply (a a' : α) : pure a a' = (if a' = a then 1 else 0) := rfl |
| 34 | + |
| 35 | +instance [inhabited α] : inhabited (pmf α) := ⟨pure (default α)⟩ |
| 36 | + |
| 37 | +lemma coe_le_one (p : pmf α) (a : α) : p a ≤ 1 := |
| 38 | +is_sum_le (by intro b; split_ifs; simp [h]; exact le_refl _) (is_sum_ite a (p a)) p.2 |
| 39 | + |
| 40 | +protected lemma bind.has_sum (p : pmf α) (f : α → pmf β) (b : β) : has_sum (λa:α, p a * f a b) := |
| 41 | +begin |
| 42 | + refine nnreal.has_sum_of_le (assume a, _) p.has_sum_coe, |
| 43 | + suffices : p a * f a b ≤ p a * 1, { simpa }, |
| 44 | + exact mul_le_mul_of_nonneg_left ((f a).coe_le_one _) (p a).2 |
| 45 | +end |
| 46 | + |
| 47 | +def bind (p : pmf α) (f : α → pmf β) : pmf β := |
| 48 | +⟨λb, (∑a, p a * f a b), |
| 49 | + begin |
| 50 | + simp [ennreal.is_sum_coe_iff.symm, ennreal.coe_tsum (bind.has_sum p f _)], |
| 51 | + rw [is_sum_iff_of_has_sum ennreal.has_sum, ennreal.tsum_comm], |
| 52 | + simp [ennreal.mul_tsum, (ennreal.coe_tsum (f _).has_sum_coe).symm, |
| 53 | + (ennreal.coe_tsum p.has_sum_coe).symm] |
| 54 | + end⟩ |
| 55 | + |
| 56 | +@[simp] lemma bind_apply (p : pmf α) (f : α → pmf β) (b : β) : p.bind f b = (∑a, p a * f a b) := rfl |
| 57 | + |
| 58 | +lemma coe_bind_apply (p : pmf α) (f : α → pmf β) (b : β) : |
| 59 | + (p.bind f b : ennreal) = (∑a, p a * f a b) := |
| 60 | +eq.trans (ennreal.coe_tsum $ bind.has_sum p f b) $ by simp |
| 61 | + |
| 62 | +@[simp] lemma pure_bind (a : α) (f : α → pmf β) : (pure a).bind f = f a := |
| 63 | +have ∀b a', ite (a' = a) 1 0 * f a' b = ite (a' = a) (f a b) 0, from |
| 64 | + assume b a', by split_ifs; simp; subst h; simp, |
| 65 | +by ext b; simp [this] |
| 66 | + |
| 67 | +@[simp] lemma bind_pure (p : pmf α) : p.bind pure = p := |
| 68 | +have ∀a a', (p a * ite (a' = a) 1 0) = ite (a = a') (p a') 0, from |
| 69 | + assume a a', begin split_ifs; try { subst a }; try { subst a' }; simp * at * end, |
| 70 | +by ext b; simp [this] |
| 71 | + |
| 72 | +@[simp] lemma bind_bind (p : pmf α) (f : α → pmf β) (g : β → pmf γ) : |
| 73 | + (p.bind f).bind g = p.bind (λa, (f a).bind g) := |
| 74 | +begin |
| 75 | + ext b, |
| 76 | + simp only [ennreal.coe_eq_coe.symm, coe_bind_apply, ennreal.mul_tsum.symm, ennreal.tsum_mul.symm], |
| 77 | + rw [ennreal.tsum_comm], |
| 78 | + simp [mul_assoc, mul_left_comm, mul_comm] |
| 79 | +end |
| 80 | + |
| 81 | +lemma bind_comm (p : pmf α) (q : pmf β) (f : α → β → pmf γ) : |
| 82 | + p.bind (λa, q.bind (f a)) = q.bind (λb, p.bind (λa, f a b)) := |
| 83 | +begin |
| 84 | + ext b, |
| 85 | + simp only [ennreal.coe_eq_coe.symm, coe_bind_apply, ennreal.mul_tsum.symm, ennreal.tsum_mul.symm], |
| 86 | + rw [ennreal.tsum_comm], |
| 87 | + simp [mul_assoc, mul_left_comm, mul_comm] |
| 88 | +end |
| 89 | + |
| 90 | +def map (f : α → β) (p : pmf α) : pmf β := bind p (pure ∘ f) |
| 91 | + |
| 92 | +lemma bind_pure_comp (f : α → β) (p : pmf α) : bind p (pure ∘ f) = map f p := rfl |
| 93 | + |
| 94 | +lemma map_id (p : pmf α) : map id p = p := by simp [map] |
| 95 | + |
| 96 | +lemma map_comp (p : pmf α) (f : α → β) (g : β → γ) : (p.map f).map g = p.map (g ∘ f) := |
| 97 | +by simp [map] |
| 98 | + |
| 99 | +lemma pure_map (a : α) (f : α → β) : (pure a).map f = pure (f a) := |
| 100 | +by simp [map] |
| 101 | + |
| 102 | +def seq (f : pmf (α → β)) (p : pmf α) : pmf β := f.bind (λm, p.bind $ λa, pure (m a)) |
| 103 | + |
| 104 | +def of_multiset (s : multiset α) (hs : s ≠ 0) : pmf α := |
| 105 | +⟨λa, s.count a / s.card, |
| 106 | + have s.to_finset.sum (λa, (s.count a : ℝ) / s.card) = 1, |
| 107 | + by simp [div_eq_inv_mul, finset.mul_sum.symm, (finset.sum_nat_cast _ _).symm, hs], |
| 108 | + have s.to_finset.sum (λa, (s.count a : nnreal) / s.card) = 1, |
| 109 | + by rw [← nnreal.eq_iff, nnreal.coe_one, ← this, ← nnreal.sum_coe]; simp, |
| 110 | + begin |
| 111 | + rw ← this, |
| 112 | + apply is_sum_sum_of_ne_finset_zero, |
| 113 | + simp {contextual := tt}, |
| 114 | + end⟩ |
| 115 | + |
| 116 | +end pmf |
0 commit comments