Require Import ssreflect ssrfun ssrbool ssrnat bigop.
Import Prenex Implicits.

Add Rec LoadPath "$ALEA_LIB/ALEA/src" as ALEA.
Add Rec LoadPath "$ALEA_LIB/Continue".
Require Import my_ssralea.
Require Export Prog.
Require Export Cover.
Require Import Ccpo.
Set Implicit Arguments.

 
Open Local Scope U_scope.


(** * Extra Lemmas for Alea 
 *)


(** Fixpoint rule for [Mfix F]: bounds on the fixpoint are derived from
    bounds on the finite iterates [Miter F n]. *)
Section fixP.

(** [F] is a monotonic functional on functions [A -> distr B], [q] is the
    postcondition, and [PR] restricts the inputs the rule talks about. *)
Variables A B : Type.
Variable F : (A -> distr B) -m> (A -> distr B).
Variable q : A -> B -> U.
Variable PR : A -> Prop. 


(** If [p x O] is 0 for every [x], and one application of [F] turns any
    [f] meeting the bounds [p x i] (on inputs satisfying [PR]) into one
    meeting the bounds [p x (S i)], then [Mfix F x] meets the bound
    [Ulub (p x)] for every [x] satisfying [PR]. *)
Lemma Pfixrule_Ulub : forall (p : A -> nat -> U),
   (forall x:A, p x O == 0)->
   (forall (i:nat) (f:A -> distr B),
      (forall x: A, PR x -> ok (p x i) (f x) (q x)) -> 
       forall x: A, PR x -> ok (p x (S i))  (F f x) (q x)) -> 
       forall x: A, PR x -> ok (Ulub (p x)) (Mfix F x) (q x).
red; intros p p0 Hrec x Hx.
(* Auxiliary claim H: the n-th iterate [Miter F n] meets the n-th bound,
   proved by induction on [n]. *)
assert ( forall (n : nat) (x : A), PR x -> ok (p x n) ((Miter F) n x) (q x)).
 induction n; simpl;auto. 
 intros;red;auto. 
(* Bound every [p x n] by the measure under the n-th iterate, which lies
   below the measure under the fixpoint by [Mfix_le_iter].
   NOTE(review): this relies on [ok] unfolding to an inequality on [mu],
   as the [transitivity] step below suggests. *)
apply Ulub_le; auto.
intro n. 
transitivity (mu (Miter F n x) (q x)). 
apply (H n x Hx).
apply Mfix_le_iter; auto.
Save.


(** Variant of [Pfixrule_Ulub] for families of monotonic sequences
    [p x : nat -m> U], with the conclusion stated with [lub]. *)
Lemma Pfixrule : forall (p : A -> nat -m> U),
   (forall x:A, p x O == 0)->
   (forall (i:nat) (f:A -> distr B),
   (forall x : A, PR x -> ok ((p x) i) (f x) (q x)) ->
    forall x : A, PR x ->ok ((p x) (S i)) (F f x) (q x)) ->
    forall x : A, PR x -> ok (lub (p x)) (Mfix F x) (q x).
red; intros.
(* Replace [lub] by [Ulub] and reuse the previous lemma. *)
rewrite <- Ulub_lub.
apply (Pfixrule_Ulub p H H0 H1).
Save.

End fixP. 



(** [[1/4] + [1/4] == [1/2]], proved through the intermediate value
    [2 */ [1/4]]. *)
Lemma quarterUplus: Uplus [1/4] [1/4] == [1/2].
Proof.
 by (transitivity (2 */ [1/4])%U; auto).
Qed. 

(** [quarterUplus] scaled by a common factor [n].
    NOTE(review): despite its name, [n] is combined with [[1/4]] by the
    [U] multiplication here, so it is presumably a value of [U], not a
    natural number. *)
Lemma quarterUplusn n:  
 Uplus ([1/4] * n)  ([1/4] * n) == ([1/2] * n)%U.
Proof. 
(* Commute the products, factor [n] out with [Udistr_plus_left], then use
   [quarterUplus]; the last [auto] presumably closes the side condition
   generated by the distributivity rewrite. *)
rewrite Umult_sym. rewrite -Udistr_plus_left.
rewrite quarterUplus. auto. 
auto. 
Qed.  


(** [pmin2 n] is 0 for [n] equal to 0 or 1, and [pmin 1 m] when [n] is
    [m+2]: the sequence [pmin 1] delayed by two steps. *)
Definition pmin2 n := match n with
 | O => 0
 | 1 => 0
 | S (S n) => pmin 1 n
end.

(** [pmin2] is monotonic. The proof does a case analysis on both indices.
    Cases where the larger index is structurally smaller are refuted with
    [le_Sn_0]. The generic case reduces to [pmin_le_compat] after
    stripping the two successors from the hypothesis. *)
Instance pmin2_mon : monotonic pmin2.
red; auto.
intros x y H;unfold pmin2. 
case:x H;case:y=>//=. 
 intros n H; by destruct (le_Sn_0 n). 
intros a b;case:a;case b=>//=. 
 intros n H;apply le_S_n in H;by destruct (le_Sn_0 n). 
intros n n' H; apply pmin_le_compat=>//. 
apply le_S_n in H;by apply le_S_n in H.
Qed. 

(** [pmin2] packaged as a monotonic sequence [nat -m> U]. *)
Definition Pmin2 :nat -m> U := mon pmin2.

(** The least upper bound of [Pmin2] is 1. *)
Lemma lubp2: lub Pmin2 == 1%U.
Proof.
(* Reduce to showing that every [t] with [t < 1] lies below the lub. *)
apply Uge_one_eq;apply Ule_lt_lim; intros.
(* Key fact H0, proved at the end: some index [n] satisfies [t <= 0] when
   [n] is 0, and [t <= 1 - [1/]1+m] when [n] is [S m]. *)
assert (exc (fun n : nat => match n with |O => t <= 0
 |S n => t <= 1 - [1/]1+n end));last first.
 (* Main goal from the witness: the index 0 gives [t == 0]; otherwise
    chain [1 - [1/]1+n <= pmin 1 n], which is the term of [Pmin2] at
    index [n+2] and hence below the lub. *)
 apply H0;auto; intros n H1.
 case:n H1.
  by intro;apply Ule_zero_eq in H1;rewrite H1.
 intros;transitivity (1 - [1/]1+n); auto.
 transitivity (pmin 1 n); auto.
 apply (le_lub (Pmin2) n.+2).

(* Proof of H0: either [t == 0] (witness 0), or, using [t < 1],
   [Ult_le_nth_minus] yields some [n] with [t <= 1 - [1/]1+n]
   (witness [n+1]). *)
generalize (Ueq_orc t 0);unfold orc.  
intros;unfold exc;intros. 
apply (H0 C H1);intros.
apply (H2 O);rewrite H3;auto. 
assert (forall n,  t <= 1 - [1/]1+n -> C).
 intros;apply (H2 n.+1);auto. 
apply (@Ult_le_nth_minus t 1 H);auto. 
Qed.



(** [pmin1 n] is 0 for [n] equal to 0, and [pmin 1 m] when [n] is [m+1]:
    the sequence [pmin 1] delayed by one step. *)
Definition pmin1 n := match n with
 | O => 0
 | (S n) => pmin 1 n
end.

(** [pmin1] is monotonic. The case [S n <= 0] is refuted with [le_Sn_0].
    The generic case reduces to [pmin_le_compat] after stripping one
    successor from the hypothesis. *)
Instance pmin1_mon : monotonic pmin1.
red; auto.
intros x y H;unfold pmin1. 
case:x H;case:y=>//=. 
 intros n H; by destruct (le_Sn_0 n). 
intros n n' H; apply pmin_le_compat=>//. 
by apply le_S_n in H.
Qed. 

(** [pmin1] packaged as a monotonic sequence [nat -m> U]. *)
Definition Pmin1 :nat -m> U := mon pmin1.

(** The least upper bound of [Pmin1] is 1.
    NOTE(review): this proof is a copy of the proof of [lubp2]; only the
    index passed to [le_lub] differs ([n+1] instead of [n+2]). *)
Lemma lubp1: lub Pmin1 == 1%U.
Proof.
(* Reduce to showing that every [t] with [t < 1] lies below the lub. *)
apply Uge_one_eq;apply Ule_lt_lim; intros.
(* Key fact H0, proved at the end: some index [n] satisfies [t <= 0] when
   [n] is 0, and [t <= 1 - [1/]1+m] when [n] is [S m]. *)
assert (exc (fun n : nat => match n with |O => t <= 0
 |S n => t <= 1 - [1/]1+n end));last first.
 (* Main goal from the witness: the index 0 gives [t == 0]; otherwise
    chain [1 - [1/]1+n <= pmin 1 n], which is the term of [Pmin1] at
    index [n+1] and hence below the lub. *)
 apply H0;auto; intros n H1.
 case:n H1.
  by intro;apply Ule_zero_eq in H1;rewrite H1.
 intros;transitivity (1 - [1/]1+n); auto.
 transitivity (pmin 1 n); auto.
 apply (le_lub (Pmin1) n.+1).

(* Proof of H0: either [t == 0] (witness 0), or, using [t < 1],
   [Ult_le_nth_minus] yields some [n] with [t <= 1 - [1/]1+n]
   (witness [n+1]). *)
generalize (Ueq_orc t 0);unfold orc.  
intros;unfold exc;intros. 
apply (H0 C H1);intros.
apply (H2 O);rewrite H3;auto. 
assert (forall n,  t <= 1 - [1/]1+n -> C).
 intros;apply (H2 n.+1);auto. 
apply (@Ult_le_nth_minus t 1 H);auto. 
Qed.




(** [pcte1 p n] is 0 at index 0 and [p] at every later index. *)
Definition pcte1 (p:U) n := match n with
 | O => 0
 | (S n) => p
end.

(** For every [p], [pcte1 p] is monotonic. The only case not closed by
    computation, [S n <= 0], is refuted with [le_Sn_0].
    NOTE(review): this instance ends with [Defined] (transparent) while
    the similar instances [pmin1_mon] and [pmin2_mon] end with [Qed];
    confirm whether transparency is actually needed here. *)
Instance pcte1_mon : forall p, monotonic (pcte1 p).
red; auto.
intros p x y H;unfold pcte1.
case:x H=>//;case y=>//. 
intros. by destruct (le_Sn_0 n). 
Defined.

(** [pcte1 p] packaged as a monotonic sequence [nat -m> U]. *)
Definition Pcte1 (p:U) :nat -m> U := mon (pcte1 p).

(** The least upper bound of [Pcte1 p] is [p]. The value [p] is reached
    at index 1, and every term is at most [p]. *)
Lemma lubpcte1 : forall p, lub (Pcte1 p) == p%U.
Proof.
intro. 
(* h: the term at index 1, which is [p], is below the lub. *)
generalize (le_lub (Pcte1 p)).
 simpl;unfold pcte1. intro. 
 have h:= (H 1%nat). clear H.
(* h': the lub is below [p] as soon as every term is. *)
generalize (lub_le (Pcte1 p)). 
 simpl. intro. have h' :=(H p).
apply Ole_antisym;auto.
apply h';intro n. 
unfold pcte1. case:n=>//. 
Qed. 





(** [pqmin p q n] is [p - q ^ n], computed in [U]. *)
Definition pqmin (p:U) (q:U) (n:nat) :=  p - ( q ^ n ).

(** For all [p] and [q], [pqmin p q] is monotonic in [n]. The proof
    applies [Uminus_le_compat_right], then compares the powers of [q]
    with [Uexp_le_compat]. *)
Instance pqmin_mon : forall p q, monotonic (pqmin p q).
red;unfold pqmin;intros. 
apply Uminus_le_compat_right.
apply Uexp_le_compat;auto.
Qed. 

(** [pqmin p q] packaged as a monotonic sequence [nat -m> U]. *)
Definition Pqmin (p q:U) :nat -m> U := mon (pqmin p q).

(** [Uq1min q] is the monotonic sequence [n |-> 1 - q ^ n]. *)
Definition Uq1min := Pqmin 1.

(** When [q < 1], the least upper bound of the sequence [1 - q ^ n] is 1. *)
Lemma eq_lim_Uq1min : forall q, q < 1 -> lub (Uq1min q) == 1.
Proof.
(* Go through the [mlub] of the sequence [[1-] q ^ n]; [auto] closes both
   the identification with it and its limit. *)
intros;unfold Uq1min,Pqmin,pqmin.
transitivity (mlub (fun n : nat => [1-] q ^ n));auto.  
Qed.

(** Recurrence for [Uq1min] at [q = [1-]p]:
    [1 - q ^ (n+1) == p + (1 - q ^ n) * q]. *)
Lemma Uq1min_S : forall n p, 
 (Uq1min ([1-]p)) (S n) == p + (Uq1min ([1-]p)) n * ([1-]p).
(* Unfold to powers, distribute the subtraction over the product, then
   reassociate and simplify [p + [1-]p]-style terms. *)
intros;simpl. unfold pqmin;simpl.
rewrite Uminus_distr_left. Usimpl. 
rewrite Umult_sym. 
rewrite Uplus_minus_assoc_right;auto.
rewrite Uinv_opp_right.  auto.
Qed. 

(** The first term of [Uq1min q] is 0, since [q ^ 0] is 1. *)
Lemma Uq1min_0 : forall q, (Uq1min q) O == 0.
Proof.
unfold Uq1min; simpl; auto.
Qed. 



(** [compn f x] gives equal results on pointwise-equal sequences, provided
    [f] respects [==] in both arguments. Proof by induction on [n]. *)
Lemma compn_morph : forall (f : U -> U-> U) (x: U) 
 (u1 : nat -> U) (u2: nat -> U) (n: nat),
 (forall x y x0 y0 : U, x == y -> x0 == y0 -> f x x0 == f y y0) ->              
 (forall y, u1 y == u2 y) -> compn f x u1 n == compn f x u2 n.
Proof.
move => f x u1 u2 n hmf h; elim : n =>//.
move=> n;simpl;have := (h n);apply hmf.
Qed.


(** Exchanging the bounds of a double finite sum. The summand keeps the
    argument order [f k l] on both sides while the ranges of [k] and [l]
    are swapped, which is why symmetry of [f] is assumed.
    Proof by induction on [a], then on [b], rearranging with the
    associativity and commutativity lemmas of [Uplus]. *)
Lemma sigma_compo : forall (f : nat -> nat -> U) (a b:nat),
 (forall x y,f x y == f y x) -> 
(sigma (fun k : nat => (sigma (fun l : nat => f k l) b)) a) ==
(sigma (fun k : nat => (sigma (fun l : nat => f k l) a)) b).
Proof.
intros f a b H;move:b;induction a. 
 induction b =>/=;auto. 
intro b;rewrite sigma_S;rewrite IHa;induction b;auto. 
symmetry;rewrite sigma_S -IHb.
symmetry;rewrite Uplus_sym sigma_S Uplus_perm3 -Uplus_assoc Uplus_sym 
 -Uplus_assoc.
apply Uplus_eq_compat=>//.
do 2 rewrite sigma_S; rewrite Uplus_sym.
symmetry;rewrite -Uplus_assoc Uplus_sym -Uplus_assoc.
apply Uplus_eq_compat=>//.
rewrite Uplus_sym;apply Uplus_eq_compat=>//.
Qed.

(** The measure of the conjunction [fconj f g] is at most the measure of
    [f]. The proof rewrites the right side as [1 * mu m f], rewrites the
    left side with [Mcond_conj], and compares the products factor-wise
    with [Umult_le_compat]. *)
Lemma mu_cond_le : forall (A : Type) (m : distr A) (f g : MF A),
(mu m) (fconj f g) <= (mu m) f. 
Proof. 
move=>a  m f g.
rewrite -(Umult_one_left ((mu m)f)) Mcond_conj.
by apply Umult_le_compat.
Qed. 


(** * Extra Lemmas for the Rplus library of Alea
 *)


(** Lemmas stated in [Rp], the type of the [Rplus] library loaded below.
    NOTE(review): the [Require] happens inside this section; recent Coq
    versions warn about [Require] inside sections. *)
Section Rplus.

Require Import Rplus. 
Open Scope Rp_scope.

(** Doubling in [Rp]: [2 * x == x + x]. *)
Lemma Rp_double1 : forall x, 
 (2 * x) == (x + x).
Proof.
move=>x. 
by rewrite NRpmult_mult NRpmult_S NRpmult_1.
Qed. 

(** The image of [S n] in [Rp] is [R1 + n]. *)
Lemma N2Rp_S_plus_1 : forall n, N2Rp (S n) == R1 + n.
Proof.
intros; rewrite N2Rp_plus; auto.
Qed.

(** [U2Rp ([1/]1+n) + R1 == (n+2) * U2Rp ([1/]1+n)].
    NOTE(review): the name coincides with [divn1] of ssreflect's [div]
    library; importing both unqualified would make one shadow the other. *)
Lemma divn1 : forall n, 
(U2Rp ([1/]1+n)) + R1 == n.+2 * (U2Rp ([1/]1+n)).
Proof.
move=>n;rewrite -(@Rp1div_left (n.+1)) U2Rp_Unth.
have := (Rpdistr_plus_right R1 (N2Rp n.+1) ([1/]N2Rp n.+1)).
rewrite -N2Rp_S_plus_1=>->;auto.
Qed.

(** Summing [n] copies of [x] in [Rp] gives [n * x]. Proof by induction
    on [n]. *)
Lemma Rpsigma_const : forall (n : nat) (x : Rp), 
 (Rpsigma (fun _ : nat => x)) n == (n * x)%Rp.
Proof.
induction n; auto.
intro x;rewrite Rpsigma_S IHn N2Rp_S_plus_1 Rpdistr_plus_right;auto. 
Qed. 

(** [U2Rp ([1/]1+x) * (x+1) == R1]. *)
Lemma Unth_mult_eq : forall x, 
 (U2Rp([1/]1+x) * x.+1)%Rp == R1.  
Proof. 
move=>x;rewrite -(@Rp1div_right (x.+1)).
apply Rpmult_eq_compat_left.
rewrite U2Rp_Unth;auto. 
Qed. 

Close Scope Rp_scope. 
End Rplus.

Open Local Scope U_scope.
Open Local Scope O_scope.

(** Swapping two constant factors across a finite sum:
    [c1 * sigma (c2 * f k) n == c2 * sigma (c1 * f k) n].
    The [retract] hypotheses are the side conditions of [sigma_mult], used
    to move constants in and out of [sigma]. *)
Lemma sigma_mult_perm :
  forall (f : nat -> U) n c1 c2, retract (fun k => c1 * (f k)) n -> retract (fun k => c2 * (f k)) n
  -> c1 * (sigma (fun k => c2 * (f k)) n) == c2 * (sigma (fun k => c1 * (f k)) n).
Proof.
intros; transitivity (sigma (fun k => c2 * (c1 * (f k))) n); auto.
rewrite <- sigma_mult; auto.
Qed.

(** Under [retract f n], the sum computed in [Rp] ([Rpsigma]) agrees with
    the [U]-sum [sigma f n] mapped into [Rp]. Proof by induction on [n],
    using [U2Rp_plus_le] at each step. *)
Lemma Rpsigma_U2Rp : forall (f : nat -> U) n, retract f n 
    -> Rpsigma f n == sigma f n.
induction n; intros.
rewrite Rpsigma_0; rewrite sigma_0; auto.
rewrite Rpsigma_S;rewrite sigma_S.
rewrite <- U2Rp_plus_le; auto.
Save.
(* Registered so that [auto] can use it when proving goals about sums. *)
Hint Resolve Rpsigma_U2Rp.

(** Averaging identity. Write [S] for [sigma (fun i => [1/]1+n * f i) n.+1].
    Then [[1/]1+n.+1 * S + sigma (fun i => [1/]1+n.+1 * f i) n.+1 == S].
    The proof commutes the constant factors ([sigma_mult_perm]), moves the
    equality to [Rp] ([U2Rp_eq_simpl]) and simplifies with [divn1]. *)
Lemma sigma_dist1 : forall n (f:nat -> U), 
  [1/]1+n.+1 * (sigma (fun i => [1/]1+n * f i)) n.+1 +
  (sigma (fun i => [1/]1+n.+1 * f i)) n.+1             ==
  (sigma (fun i => [1/]1+n * f i)) n.+1.
Proof.
move=>n f;rewrite sigma_mult_perm;auto;last by apply retract_pred;auto.
apply U2Rp_eq_simpl;rewrite -U2Rp_plus_le.
 (* Main goal, now in [Rp]. *)
 rewrite U2Rp_mult.
 repeat rewrite -Rpsigma_U2Rp;auto;last by apply retract_pred;auto. 
 rewrite -(Rpmult_one_left (fmont (Rpsigma 
   (fun x : nat => U2Rp ([1/]1+n.+1  * (f x)))) n.+1)). 
 rewrite Rpmult_assoc -Rpdistr_plus_right Rpmult_one_right divn1 -Rpmult_assoc 
  -Rpsigma_mult. 
 rewrite (Rpsigma_eq_compat _ 
   (fun k : nat => (U2Rp ([1/]1+n.+1) * U2Rp ([1/]1+n * f k))%Rp));last first.
  move=>k hk;repeat rewrite U2Rp_mult;rewrite Rpmult_sym 
   (Rpmult_sym (U2Rp ([1/]1+n)));auto.
 rewrite Rpsigma_mult Rpmult_assoc U2Rp_Unth Rp1div_left;auto.
(* Side condition produced by [U2Rp_plus_le], also checked in [Rp] with
   the same rearrangement as above. *)
apply U2Rp_le_simpl;rewrite U2Rp_mult U2Rp_Uinv.
rewrite -Rpsigma_U2Rp;last by apply retract_pred;auto.
apply Rpplus_le_perm_right.
rewrite -(Rpmult_one_left (fmont (Rpsigma (fun x : nat => 
 U2Rp ([1/]1+n.+1  * (f x)))) n.+1)). 
rewrite Rpmult_assoc -Rpdistr_plus_right Rpmult_one_right divn1 -Rpmult_assoc 
 -Rpsigma_mult. 
rewrite (Rpsigma_eq_compat _ 
 (fun k : nat => (U2Rp ([1/]1+n.+1) * U2Rp ([1/]1+n * f k))%Rp));last first.
 move=>k hk;repeat rewrite U2Rp_mult. 
 rewrite Rpmult_sym (Rpmult_sym (U2Rp ([1/]1+n)));auto. 
rewrite Rpsigma_mult Rpmult_assoc U2Rp_Unth Rp1div_left Rpmult_one_left
 Rpsigma_U2Rp;auto.
Qed. 

(** Splitting a finite product: the product of the first [n] terms times
    the product of the next [m] terms (indices shifted by [n]) equals the
    product of the first [n+m] terms. Proof by induction on [m]. *)
Lemma prod_comp1 : forall (n m:nat) (f:nat->U),
  prod [eta f] n * prod (fun x : nat => f (x + n)%nat) m == 
  prod [eta f] (n+m)%nat.
Proof.
move=>n m;elim: m n.
 move=>n f;rewrite prod_0;Usimpl;auto.
move=>m IHm n f;rewrite addnS prod_S prod_S -IHm.
rewrite Umult_assoc Umult_assoc;apply Umult_eq_compat_left.
rewrite Umult_sym;apply Umult_eq_compat_left.
replace (m+n)%nat with (n+m)%nat=>//.
apply plus_comm. 
Qed.

(** The product of [n] copies of [f] times the product of [n] copies of
    [g] is the product of [n] copies of [f * g]. Proof by induction on
    [n]. *)
Lemma prod_comp2 : forall (n:nat) (f g:U),
  prod (fun _ => f) n * prod (fun _ => g) n == 
  prod (fun _ => f * g) n.
Proof.
move=>n;elim:n;first by repeat rewrite prod_0. 
move=>n hind f g. 
repeat rewrite prod_S;rewrite -hind.
do 2 rewrite -Umult_assoc;apply Umult_eq_compat_right=>//.
Qed. 

(** If [a0 <= a1], then [a1] can be written as [a0 + x] with
    [a0 <= [1-] x]; the witness is [a1 - a0]. *)
Lemma ex_le1 : forall a0 a1, a0 <= a1 -> 
 exists x, (@Oeq U ordU (a0 + x) a1) /\ a0 <= [1-] x.
Proof.
move=>a0 a1 h;exists (a1 - a0).
split;auto.  
rewrite -Uminus_plus_perm_right;repeat Usimpl;auto.
Qed.


(** The inequality [2ab <= a^2 + b^2] divided by 4, stated in [U]:
    [[1/2] * (a * b) <= ([1/2] * a)^2 + ([1/2] * b)^2], with the squares
    written as products. The proof treats both orderings of [a] and [b]
    ([Ule_total]); in each case the larger value is written as the smaller
    one plus some [x] ([ex_le1]), then both sides are expanded and
    compared. *)
Lemma id_rem0:  forall (a b: U),
 [1/2]* (a * b) <= [1/2] * a *( [1/2] * a) + [1/2] * b * ([1/2] * b).
Proof.
move=>a0 a1. 
apply (Ule_total a0 a1);auto=>h;apply ex_le1 in h;destruct h as [x [H H']];
rewrite -H Umult_assoc. 
 (* Case [a0 <= a1]: [a1] is rewritten as [a0 + x]. *)
 apply (Ole_trans _ _ _ (Udistr_plus_left_le ([1/2] * a0) a0 x )).
 repeat rewrite Udistr_plus_left;repeat Usimpl;auto.
 repeat rewrite Udistr_plus_right;repeat Usimpl;auto.
 repeat rewrite Uplus_assoc;repeat rewrite -Umult_assoc.
 rewrite Unth_one_refl (Umult_perm3 x) (Umult_sym x [1/2]) -Uplus_assoc.
 rewrite -Uplus_assoc (Uplus_assoc ([1/2] * (a0 * ([1/2] * x)))) Unth_one_refl. 
 rewrite Umult_perm2;repeat Usimpl;rewrite Umult_perm2;auto.
(* Case [a1 <= a0]: [a0] is rewritten as [a1 + x]. *)
rewrite -Umult_assoc (Umult_sym _ a1) Umult_assoc.  
apply (Ole_trans _ _ _ (Udistr_plus_left_le ([1/2] * a1) a1 x)).
repeat rewrite Udistr_plus_left;repeat Usimpl;auto.
repeat rewrite Udistr_plus_right;repeat Usimpl;auto.
repeat rewrite -Uplus_assoc. 
rewrite (Uplus_assoc ([1/2] * x * ([1/2] * a1))) (Umult_sym ([1/2] *x)). 
repeat rewrite -Umult_assoc.  
rewrite Unth_one_refl Uplus_sym (Uplus_sym (  [1/2] * (a1 * ([1/2] * a1)))). 
rewrite (Umult_perm2 a1);repeat rewrite -Uplus_assoc;repeat Usimpl. 
rewrite Umult_perm3;auto.
Qed.

(** Arithmetic-geometric mean inequality in [U]:
    [a * b <= ([1/2] * a + [1/2] * b)^2], with the square written as a
    product. The right side is expanded and the cross terms are handled
    by [id_rem0]. *)
Lemma id_rem1 : forall (a b: U),
 a * b <= ([1/2] * a + [1/2] * b) * ([1/2] * a + [1/2] * b).
Proof.
move=>a0 a1;rewrite Udistr_plus_left;auto.
rewrite Udistr_plus_right;auto. 
rewrite Udistr_plus_right;auto.
rewrite Uplus_sym Uplus_assoc Uplus_sym Uplus_assoc Uplus_assoc. 
rewrite -Umult_assoc -Umult_assoc (Umult_perm3 a1) (Umult_sym a1). 
rewrite Unth_one_refl -(Unth_one_refl (a0 * a1)) Umult_perm2 -Uplus_assoc. 
apply Uplus_le_compat=>//.
rewrite -id_rem0 Umult_perm2;repeat Usimpl;auto. 
Qed.

(** Iterated form of [id_rem1]: [n] copies of [a * b] are bounded by [2n]
    copies of the average [[1/2] * a + [1/2] * b]. The [2n]-fold product
    is split in two halves ([prod_comp1]), merged pairwise
    ([prod_comp2]), and [id_rem1] is applied to each factor. *)
Lemma id_rem2 : forall (a b: U) (n:nat),
 prod (fun _ => a * b) n <= prod (fun _ =>[1/2] * a + [1/2] * b ) (2 * n)%nat.
Proof.
move=>a0 a1 n;rewrite mul2n -addnn -prod_comp1 prod_comp2.
apply prod_le_compat.
move=>k hk;apply id_rem1. 
Qed.

(** Merging two averages. Let [A1] be the average of [f 0 .. f n] and [A2]
    the average of [f (n+1) .. f (2n+1)]. Then [n+1] copies of [A1 * A2]
    are bounded by [2n+2] copies of the average of [f 0 .. f (2n+1)].
    The proof applies [id_rem2], then splits the long sum in two halves
    with [sigma_plus_lift] and compares the weights using [Unth_half]. *)
Lemma prod_sigma_id2 : forall (n:nat) (f:nat->U),
  prod
     (fun _ : nat =>
      (sigma (fun j : nat => [1/]1+n * f j)) n.+1 *
      (sigma (fun j : nat => [1/]1+n * f (j + n.+1)%nat)) n.+1) n.+1 <=
   prod
     (fun _ : nat =>
      (sigma (fun j : nat => [1/]1+(2 * n).+1 * f j)) (2 * n).+2) 
     (2 * n).+2.
Proof.
move=>n f. 
apply (Ole_trans _ _ _ (id_rem2 
        ((sigma (fun j : nat => [1/]1+n * f j)) n.+1) 
        ((sigma (fun j : nat => [1/]1+n * f (j + n.+1)%nat)) n.+1) 
        n.+1)). 
replace (2 * n.+1)%nat with (2*n)%nat.+2;last by rewrite mul2n mul2n doubleS.
apply prod_le_compat;move=>k hk.
rewrite -sigma_mult;auto;rewrite -sigma_mult;auto. 
(* h: the sum over [2n+2] indices is split into two sums of [n+1]. *)
have h:= (sigma_plus_lift (fun i => [1/]1+(2*n).+1* f i) n.+1 n.+1).
replace (n.+1 + n.+1)%coq_nat with((2*n).+2) in h;
 last by rewrite mul2n -doubleS -addnn.
rewrite h.
(* hn: the two ways of writing the weight of one term coincide. *)
have hn : (@Oeq U ordU ([1/](n.+1 + n.+1)%coq_nat) ([1/]1+(2 * n)%nat.+1)).
 apply Unth_eq_compat=>/=;rewrite mul2n -addnn;auto. 
apply Uplus_le_compat; apply sigma_le_compat;
 move=>k' hk';rewrite Umult_assoc;repeat Usimpl; 
 rewrite -(@Unth_half n.+1 (lt_0_Sn n));rewrite hn=>//;repeat Usimpl.
replace (n.+1 + k')%coq_nat with (k' + n.+1)%coq_nat=>//.
by rewrite plus_comm.
Qed.




(** * Non-zero probability
 *)

(** Positivity criterion. Assume [f] takes equal values on [a] and [b]
    whenever [0 < P a b], that [mu m (fun x => P x t)] is positive, and
    that [0 < f t]. Then [0 < mu m f].
    The proof bounds [mu m f] from below by
    [mu m (fun x => P x t * f t)], which is a product of positive
    numbers. *)
Lemma proba_not_null : forall (A:Type) (t:A) (m:distr A) (f : MF A) 
 (P: A -> A -> U),
  (forall a b, 0 < P a b -> f a == f b) ->
  0 < mu m (fun x => P x t) -> 0 < f t ->
  0 < mu m f.
Proof.
move=>A t m f P h h0 h1.
apply (Olt_le_trans _ (mu m (fun x => Umult (P x t) (f t)))).
 rewrite mu_stable_mult_right. apply (Umult_lt_zero h0 h1).
(* Pointwise, [P x t * f t == P x t * f x]: both sides vanish when
   [P x t] is 0, and otherwise [h] gives [f x == f t]. *)
rewrite (@mu_eq_compat _ _ _ (Oeq_refl m) (fun x=> (P x t * f t)%U)
 (fun x=> (P x t * f x)%U));auto. 
intro x. apply (Ueq_orc 0 (P x t));auto=>h2.
 rewrite -h2;repeat Usimpl;trivial.
by rewrite (h _ _ (Ult_neq_zero _ h2)).
Qed.



(** * Two independent events 
 *)


(** [f] and [g] are independent under [m]: the measure of their
    conjunction [fconj f g] is the product of their measures. *)
Definition indep (A:Type) (m:distr A)(f g : MF A) :=
  mu m (fconj f g) == mu m f * mu m g.


(** Conditioning on [f] leaves the measure of [g] unchanged when [g] is
    independent of [f] and [mu m f] is non-zero. *)
Lemma indep_cond : forall (A:Type) (m:distr A)(f g : MF A), 
  indep m f g -> ~ 0 == mu m f -> mu (Mcond m f) g == (mu m) g.
unfold indep; intros.
rewrite Mcond_simpl.
rewrite H; auto.
Save.

(** Product formula for independent [a] and [b]. This is the definition of
    [indep] restated as a lemma. *)
Lemma carac_prod2 : forall (A: Type) (m: distr A) (a b: A -> U), 
 indep m a b -> 
 mu m (fconj a b) == mu m a * mu m b.
Proof.
auto. 
Qed.

(** Lifts a boolean predicate to a [U]-valued function by applying [B2U]
    pointwise. *)
Definition fB2U A (a : A -> bool) : A -> U := 
  fun x => B2U (a x).

(** Independence of two boolean predicates: independence of their lifts
    by [fB2U]. *)
Definition indepb (A: Type) (m: distr A) (a b: A -> bool) := 
  indep m (fB2U a) (fB2U b).

(** Product formula for independent boolean predicates, with the
    conjunction computed by [andb]. The lift of the [andb] is shown to be
    pointwise equal to the [fconj] of the two lifts, by case analysis on
    [a x]. *)
Lemma carac_prodb : forall (A: Type) (m: distr A) (a b: A -> bool), 
 indepb m a b -> 
 mu m (fB2U (fun (x:A) => andb (a x) (b x))) == mu m (fB2U a) * mu m (fB2U b).
Proof.
intros; transitivity (mu m (fconj (fB2U a) (fB2U b))); auto.
apply mu_eq_compat; auto; intro x.
unfold fB2U,fconj.
destruct (a x); simpl; auto.
Qed.

(** Any two boolean predicates are independent under the point
    distribution [Munit x]. *)
Lemma indepb_Munit : forall (A:Type) x (f g : A -> bool),
 indepb (Munit x) f g.
intros;unfold indepb,indep;do 2 rewrite Munit_simpl;auto.
Qed. 

(** Independence of boolean predicates is symmetric. *)
Lemma indepb_sym : forall (A:Type) (m:distr A) (f g: A -> bool),
 indepb m f g <-> indepb m g f.
intros;unfold indepb,indep;setoid_rewrite fconj_def.
setoid_rewrite Umult_sym at 1 2;split;auto.
Qed. 


(** * Two composed events 
 *)

(** Setting: an outcome in [C] is built by [compose] from a draw in [DA]
    followed by a draw in [DB]. *)
Section composed.

(** A distribution is total when the measure of the constant function
    [fone A] is 1. *)
Definition Total {A:Type}(DA:distr A) := Oeq (mu DA (fone A)) 1%U.

Variables A B C: Type.
Variable compose : A -> B -> C.

Variable DA : distr A.
Variable DB : distr B.

Hypothesis HA : Total DA. 
Hypothesis HB : Total DB. 

(* [F] draws [k] from [DA], then [k'] from [DB], and returns
   [compose k k']. *)
Let F := Mlet DA
        (fun k =>  Mlet DB (fun k' =>  Munit (compose k k'))).

Section on_f.

(** [fC] factors through [compose] as the product of [fA] and [fB]
    (hypothesis [HAB]). *)
Variables (fA : A -> U)(fB:B->U)(fC : C->U).  

Hypothesis HAB: forall a b,  Oeq (fC (compose a b)) (fA a * fB b)%U.

Let X := mu F fC.


(** Integrating out the [DB] draw: [X] equals the [DA]-measure of
    [fun x => fA x * mu DB fB]. *)
Lemma L00: X == mu DA  (fun x => ((fA x) * (mu DB fB)) %U). 
Proof.
unfold X,F.
setoid_rewrite Mlet_simpl.
apply mu_eq_compat =>//;intros k.
setoid_rewrite <- mu_stable_mult.
setoid_rewrite Mlet_simpl. 
setoid_rewrite Munit_simpl. 
setoid_rewrite HAB. 
setoid_rewrite fmult_def.
trivial.
Qed.


(** Factorisation: [X] is the product of [mu DA fA] and [mu DB fB]; the
    constant [mu DB fB] is pulled out with [mu_stable_mult]. *)
Lemma L01 : X == (mu DA  fA * mu DB fB)%U. 
Proof.
rewrite L00.
transitivity ((mu DA) (fun x : A => ((mu DB) fB * fA x))%U).
 apply mu_eq_compat;auto. 
setoid_rewrite <-fmult_def.
by rewrite  mu_stable_mult.
Qed.

End on_f.


(** The composed distribution [F] is total.
    NOTE(review): the final [auto] presumably uses [HA] and [HB]. *)
Lemma F_total : Total F.
Proof.
red;unfold F.
setoid_rewrite Mlet_simpl. 
setoid_rewrite (fun x =>Mlet_simpl DB (fun k' : B => Munit (compose x k'))).
transitivity ((mu DA) (fone A));auto. 
Qed.

End composed.


(** * Discrete sigma distributions 
 *)

(** Finite discrete measures: weights [c k] at points [p k] for [k < n]. *)
Section Discrete_s.

(** For fixed weights, points and bound, the map
    [f |-> sigma (fun k => c k * f (p k)) n] is monotonic. *)
Instance discrete_s_mon : forall A (c : nat -> U) (p : nat -> A) (n:nat),
     monotonic (fun f : A -> U => sigma (fun k => c k * f (p k)) n).
red; intros; auto.
Save.

(** The finite weighted sum as a monotonic operator [M A]. *)
Definition discrete_s A (c : nat -> U) (p : nat -> A) (n:nat): M A := 
       mon (fun f : A -> U => sigma (fun k => c k * f (p k)) n).

(** Unfolding equation for [discrete_s]; it holds by conversion. *)
Lemma discrete_s_simpl : forall A (c : nat -> U) (p : nat -> A) f (n:nat), 
     discrete_s c p n f = sigma (fun k => c k * f (p k)) n.
trivial.
Save.

(** Stability under [finv], provided the weights satisfy [retract c n]. *)
Lemma discrete_s_stable_inv : forall A (c : nat -> U) (p : nat -> A) (n:nat), 
    retract c n -> stable_inv (discrete_s c p n).
red; intros.
repeat rewrite discrete_s_simpl.
unfold finv;rewrite sigma_inv;auto.
Qed. 

(** Stability under [fplus]: each weight is distributed over the sum
    ([Udistr_plus_left], whose side condition is the [stable_plus]
    hypothesis at [p k]), then the sum is split with [sigma_plus]. *)
Lemma discrete_s_stable_plus : forall A (c : nat -> U) (p : nat -> A) (n:nat), 
    stable_plus (discrete_s c p n).
red; intros.
repeat rewrite discrete_s_simpl.
transitivity (sigma (fun k : nat => c k * f (p k) + c k * g (p k)) n).
apply sigma_eq_compat.
intros; unfold fplus.
apply Udistr_plus_left.
apply (H (p k)).
apply sigma_plus; auto.
Save.

(** [retract] is preserved when the sequence decreases pointwise. *)
Lemma retract_le : forall (f g : nat->U) (n:nat), f <= g -> retract g n -> 
 retract f n.
red; intros.
transitivity (g k); auto.
transitivity ([1-]sigma g k); auto.
Save.

(** Stability under multiplication by a constant, provided the weights
    satisfy [retract c n]. *)
Lemma discrete_s_stable_mult : forall A (c : nat -> U) (p : nat -> A) (n:nat), 
    retract c n -> stable_mult (discrete_s c p n).
red; intros.
repeat rewrite discrete_s_simpl; unfold fmult.
transitivity (sigma (fun k0 : nat => k * (c k0 * f (p k0))) n); auto. 
apply sigma_mult; apply retract_le with c;auto.
Qed.

(** Continuity: the lub of a sequence of functions commutes with the
    finite weighted sum ([UMult_continuous_right], [sigma_lub1]). *)
Lemma discrete_s_continuous : forall A (c : nat -> U) (p : nat -> A) (n:nat), 
    continuous (discrete_s c p n).
red; intros.
rewrite discrete_s_simpl.
transitivity 
(sigma (lub (ishift (fun k => (UMult (c k) @ (h <o> (p k)))))) n).
apply sigma_le_compat; intros k H.
rewrite fcpo_lub_simpl -UMult_simpl. 
rewrite (UMult_continuous_right  (c k) (h <o> p k)).
apply lub_le_compat;intro m; auto.
rewrite sigma_lub1;apply lub_le_compat;intro m;auto. 
Save.

(** Data of a finite discrete distribution: a bound, weights satisfying
    [retract], and points. *)
Record discr_s (A:Type) : Type := 
     {bound_s : nat; coeff_s : nat -> U; 
      coeff_retr_s : retract coeff_s bound_s; points_s : nat -> A}.
Hint Resolve coeff_retr_s.

(** The distribution described by [discr_s] data. Its measure is
    [discrete_s], and its stability and continuity proofs are the lemmas
    above. *)
Definition Discrete_s : forall A,  discr_s A -> distr A.
intros A d ; exists (discrete_s (coeff_s d) (points_s d)(bound_s d)).
apply discrete_s_stable_inv; trivial.
apply discrete_s_stable_plus.
apply discrete_s_stable_mult; trivial.
apply discrete_s_continuous.
Defined.

(** Unfolding equation for [Discrete_s]; it holds by conversion. *)
Lemma Discrete_s_simpl : forall A (d:discr_s A), 
     mu (Discrete_s d) = discrete_s (coeff_s d) (points_s d) (bound_s d).
trivial.
Save.

(** [m] is equal to some finite discrete distribution. *)
Definition is_discrete_s (A:Type) (m: distr A) := 
      exists d : discr_s A, m == Discrete_s d.

(** A discrete distribution commutes with any other one in the product
    ([prod_distr_com]): the finite sum is moved out of [mu d2] with
    [mu_sigma_eq] and [mu_stable_mult]. *)
Lemma discrete_s_commute : forall A B (d1:distr A) (d2:distr B) (f:MF (A*B)),
    is_discrete_s d1 -> prod_distr_com d1 d2 f.
red; intros A B d1 d2 f ((n,cf,cfr,pt),H).
unfold arg_swap; unfold swap; simpl.
transitivity
  (mu d2 (fun x : B => discrete_s cf pt n (fun x0 : A => f (x0, x)))).
unfold discrete_s.
rewrite (mu_sigma_eq d2 (fun k x => cf k * f (pt k, x))).
transitivity (discrete_s cf pt n (fun x : A => mu d2 
 (fun x0 : B => f (x, x0)))); auto.
rewrite discrete_s_simpl; apply sigma_eq_compat; intros k H'.
apply Oeq_sym; apply (mu_stable_mult d2 (cf k) (fun x0 : B => f (pt k, x0))).
intros; apply retract_le with cf; auto.
apply (mu_stable_eq d2); intro x.
transitivity (mu (Discrete_s (Build_discr_s cfr pt)) 
 (fun x0 : A => f (x0, x))); auto.
Save.

(** Two [Mlet]s can be swapped when the first distribution is discrete. *)
Lemma is_discrete_s_swap: forall A B C (d1:distr A) (d2:distr B) 
 (f:A -> B -> distr C), 
   is_discrete_s d1 -> 
   Mlet d1 (fun x => Mlet d2 (fun y => f x y)) == 
   Mlet d2 (fun y => Mlet d1 (fun x => f x y)).
intros A B C d1 d2 f H1 ev.
transitivity (mu (prod_distr d1 d2) (fun c => mu (f (fst c) (snd c)) ev)).
reflexivity.
rewrite (discrete_s_commute  _ _ H1).
simpl; auto.
Qed.

(** The constant weights [[1/]1+n] on [S n] points satisfy [retract]. *)
Lemma retract_invn : forall n, retract (fun _ => ([1/]1+n)%U) (S n).
Proof.
auto. 
Qed. 

(** [Random n] is discrete: weight [[1/]1+n] at each point [k] with
    [k < S n] (bound and weights from [retract_invn], identity points). *)
Lemma is_discrete_Random : forall (n:nat), is_discrete_s (Random n). 
Proof.
intro n.
exists (Build_discr_s (@retract_invn n) (fun k => k)). 
intro f;auto.
Qed.

End Discrete_s.




(** * Conditional probability
 *)
(* Section name fixed: it was misspelled Conditionnal. Section names are
   not part of the qualified names of the lemmas, so this is local. *)
Section Conditional.

(* ssreflect finite-type libraries, loaded only at this point of the file.
   NOTE(review): presumably kept here so that their notations (for
   instance the boolean [==] of eqtype) do not affect the parsing of the
   lemmas above; confirm before moving them to the file header. Recent
   Coq versions also warn about [Require] inside a section. *)
Require Import ssreflect ssrfun ssrbool eqtype ssrnat.
Require Import fintype finset fingraph seq.
Import Prenex Implicits.
Require Import my_ssr. 
Require Import weird_induc.

(* [b] is an element of [B]: it is the default value of [nth] on [enum B]
   and the witness that [B] is non-empty. *)
Variables (A:Type) (B:finType) (b:B).


(**
      prodConj f a :
      product, over all elements [y] of [B], of [f y a].
      Every factor is in [U], hence so is the result.
 *)

Definition prodConj (f:B->MF A) (a:A) : U :=
  \big[(fun x : U => [eta Umult x])/1]_y f y a.


(** 
      prodConjBound f j a :
      product of [f y a] over the elements [y] of [B] whose index [i] in
      [enum B] satisfies [j.+1 <= i < #|B|].
      Every factor is in [U], hence so is the result.
 *)

Definition prodConjBound (f:B->MF A) (j:nat) (a:A) : U :=
 \big[(fun x : U => [eta Umult x])/1]_(j.+1 <= i < #|B|)
   f (nth b (enum B) i) a.


(** Chain rule. For [m] satisfying [Term], the measure of [prodConj f] is
    the product, over the indices [i < #|B|], of the measure of [f] at the
    [i]-th element of [enum B] conditioned on [prodConjBound f i] (the
    factors of larger index). Proof by induction on the enumeration of
    [B]. *)
Lemma Mcond_prodConj : forall (f:B -> MF A) (m: distr A),
 Term m ->
 mu m (prodConj f) ==
 prod (fun i => mu (Mcond m (prodConjBound f i)) 
                   (f (nth b (enum B) i)) ) 
      #|B|.
Proof.
unfold prodConj,prodConjBound;move=>f m mterm.
(* [B] is non-empty because it contains [b]. *)
have : (0 < #|B|)%nat.
 case h:(0<#|B|)%nat=>//;move/card_gt0P:h=>h;destruct h;by exists b.
rewrite cardE enumT. 
change (Finite.enum B) with (index_enum B). 
elim: (index_enum B)=>[|t s hind]=>//_.
setoid_rewrite big_cons;setoid_rewrite (fun a => Umult_sym (f t a)).
change (size (t::s)) with ((size s).+1).
case:s hind.
 (* Enumeration with a single element [t]. *)
 move=>_/=;rewrite prod_S prod_0 nth0;unfold head;Usimpl. 
 rewrite (@fconj_eq_compat _ _ (fun a : A =>1) _ (f t))=>//;
  last by intro;rewrite big_nil.
 setoid_rewrite big_nil;rewrite mterm;auto. 
(* At least two elements: peel off the head with [Mcond_conj] and use the
   induction hypothesis on the tail. *)
move=>t' s hind;rewrite Mcond_conj hind;clear hind=>//.
rewrite prod_S_lift nth0;unfold head. 
apply Umult_eq_compat=>//.
 apply mu_eq_compat=>//= x;apply Udiv_eq_compat =>//;
  last by setoid_rewrite (big_nth b);setoid_rewrite big_add1.
 apply mu_eq_compat =>//;apply fconj_eq_compat=>// a;
  by rewrite(big_nth b) big_add1.
apply prod_eq_compat=>//k H0;rewrite -nth_behead;unfold behead. 
apply mu_eq_compat=>//= x;apply Udiv_eq_compat=>//;
 last by symmetry;setoid_rewrite big_add1 at 1.
apply mu_eq_compat=>//;apply fconj_eq_compat=>// a. 
symmetry;by rewrite big_add1.
Qed.


(** Lower bound for a conditional measure. Assume [Term m] and that the
    conditioning product [prodConjBound (finv o f) k] has non-zero measure
    ([hprod]). Assume also that [f x] absorbs the factors selected by
    [P x] ([himpl]) and is independent of the product of the other
    factors ([hindep]). Then [mu m (f x)] is at most the measure of [f x]
    conditioned on [prodConjBound (finv o f) k]. *)
Lemma Mcond_prodConjBound :
 forall (f:B->MF A) (m: distr A) (x:B) (k:nat) (P:B->nat -> bool),
  Term m ->
 
 ~(mu m) (prodConjBound (fun y : B => Uprop.finv (f y)) k) == 0 ->

 (forall x0, f x x0 *
      (\big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B| | 
       P x i) Uprop.finv (f (nth b (enum B) i)) x0) == f x x0) ->

 indep m (f x)
   (fun x0 : A =>
    \big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B|)
       (if ~~ P x i then Uprop.finv (f (nth b (enum B) i)) x0 else 1)) ->

 mu m  (f x) <= 
 mu (Mcond m
         (prodConjBound (fun y => Uprop.finv (f y))
            k))
                 (f x) .
Proof.
move=>f m x k P hterm hprod himpl hindep/=;unfold prodConjBound,fconj.
setoid_rewrite (fun a a'=> Umult_sym a' (f x a)).

(* H: the conditioning product splits into the factors selected by [P x]
   and the remaining ones ([bigIDs]). *)
assert (forall a,\big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B|)
 Uprop.finv (f (nth b (enum B) i)) a == 
 \big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B| |P x i) 
   Uprop.finv (f (nth b (enum B) i)) a *
 \big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B| | ~~ P x i) 
   Uprop.finv (f (nth b (enum B) i)) a).
 generalize (fun x0=>@bigIDs _ _ 1 (fun x1 : U => [eta Umult x1]) _ 
 (index_iota k.+1 #|B|) (P x) (fun _ => true) 
 (fun i => Uprop.finv (f (nth b (enum B) i)) x0))=>/=H0 a;apply H0;auto.

(* Absorb the factors selected by [P x] into [f x] with [himpl]. *)
setoid_rewrite H. setoid_rewrite Umult_assoc;setoid_rewrite himpl.  
(* H0: the product over the other indices, rewritten with an [if] so that
   it matches [hindep] ([big_mkconds]). *)
assert (forall a, 
  @Oeq U ordU (\big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B|
  | 
      ~~ P x i) (Uprop.finv (f (nth b (enum B) i)) a))
       (\big[(fun x1 : U => [eta Umult x1])/1]_(k.+1 <= i < #|B|)
     (if ~~ (P x i) then (Uprop.finv (f (nth b (enum B) i)) a) else 1))).
move=>a;apply big_mkconds;auto.
(* Factor with independence, conclude with [Umult_div_le_left]; the last
   goal is the non-zero condition [hprod]. *)
setoid_rewrite H0. rewrite carac_prod2=>//. 
apply Umult_div_le_left;auto.
setoid_rewrite <-H0;setoid_rewrite <- H;apply neq_sym.
unfold prodConjBound in hprod;apply hprod.
Qed. 

Require Import Rplus. 
(*Set Printing Coercions.*)

(** Arithmetic-geometric mean inequality for [n+1] factors: the product
    of [f 0 .. f n] is at most the [(n+1)]-fold product of their average.
    If some [f k] is 0, the left side is 0 ([prod_zero]). Otherwise the
    proof uses the induction scheme [weird_induc] (stated in another
    file), in the style of Cauchy's forward-backward induction: a step
    from [n] to [2n+1] ([prod_sigma_id2]) and a step from [n+1] down to
    [n]. *)
Lemma prod_sigma_average : forall (n:nat) (f:nat-> U),
 prod (fun i => f i) n.+1   <=
 prod (fun _ => sigma (fun i => [1/]1+n * f i) n.+1) n.+1.
Proof.
move=>n f.
(* Decide, for each index, whether [f] vanishes there. *)
assert (forall x : nat, {f x == 0} + {~ f x == 0}).
 move=>x;apply iseq_dec.
generalize (dec_exists_lt (fun k => @Oeq U ordU (f k) 0) H n.+1)=>HYP.
destruct HYP as [e | e].
 (* Some [f k] with [k < n+1] is 0: the left product is 0. *)
 destruct e as [k [h1 h2]];rewrite (@prod_zero [eta f]  n.+1 k)=>//.
clear H;move:n f e.
apply (weird_induc  
        (fun n => forall (f:nat -> U),
            ~(exists k, (k < n.+1)%coq_nat /\ f k == 0) -> 
            prod [eta f] n.+1 <=
            prod (fun _ => (sigma (fun j => [1/]1+n * f j )) n.+1) n.+1)
        (fun n => (2*n)%nat.+1));auto with arith;last first.
(* Goal 3: presumably the base case, a single factor. *)
3:move=>f h0/=;do 2 rewrite prod_S prod_0;by repeat Usimpl.  
(* Step from [n] to [2n+1]: apply the hypothesis to both halves of the
   sequence, multiply the bounds, and merge with [prod_sigma_id2]. *)
+move=>n hind f h0. 
 have h1: (@Ole U ordU (prod [eta f] n.+1) (prod (fun _ : nat => 
             (sigma (fun j : nat => [1/]1+n * f j)) n.+1) n.+1)).
  apply (hind f);move=>[k [hk h]];destruct h0;exists k;split;auto. 
  apply (lt_trans _ _ _ hk)=>//;change n.+1 with  (0+ n.+1)%nat.
  rewrite mul2n - doubleS -addnn;have h' := (ltn_add2r n.+1 0 n.+1).
  rewrite ltn0Sn in h';move/ltP:h';auto.  
 have h2:(@Ole U ordU  (prod (fun x : nat => f (x + n.+1)%N) n.+1)
        (prod (fun _ => (sigma (fun j => [1/]1+n * f (j + n.+1)%N)) n.+1) n.+1)).
  apply (hind (fun k => (f (k + n.+1)%nat))).
  move=>[k [hk h]];destruct h0;exists (k + n.+1)%nat;split;auto.
  rewrite mul2n - doubleS -addnn;auto with arith. 
 have:=(Umult_le_compat _ _ _ _ h1 h2).
 rewrite prod_comp1;replace (n.+1 + n.+1)%nat with ((2 *n)%nat.+2);
  last by rewrite mul2n -doubleS addnn.
 rewrite prod_comp2;move=>h;apply (Ole_trans _ _ _ h).
 apply prod_sigma_id2. 
(* Step from [n+1] down to [n]: apply the hypothesis to [f] extended with
   the average of [f 0 .. f n] at index [n+1], then cancel that extra
   factor ([Umult_le_simpl_left]), which is non-zero since no [f k] is. *)
+move=>n hind f h0. 
 have:= (hind (fun x => if (x == n.+1)%bool 
                       then (sigma (fun j => [1/]1+n * f j) n.+1) 
                       else (f x))).
 rewrite prod_S eq_refl (prod_eq_compat _ (fun x =>f x));last first.
  move=>k hk;case h:(k == n.+1)=>//.
  move/eqP:h hk->=>h;by destruct (lt_irrefl n.+1). 
 rewrite (sigma_S _ n.+1) eq_refl.
 rewrite (prod_eq_compat 
 (fun _ => [1/]1+n.+1 * (sigma (fun j => [1/]1+n * f j)) n.+1 +
           (sigma (fun j => [1/]1+n.+1 * (if j == n.+1
           then (sigma (fun j0 : nat => [1/]1+n * f j0)) n.+1
           else f j))) n.+1)
 (fun _ => [1/]1+n.+1 * (sigma (fun j => [1/]1+n * f j)) n.+1 +
          (sigma (fun i => [1/]1+n.+1 * f i) n.+1)));last first.
  move=>k hk;repeat Usimpl;apply sigma_eq_compat=>k' hk';case h:(k'==n.+1)=>//.
  move/eqP:h hk'-> =>h;by destruct (lt_irrefl n.+1).
 rewrite (prod_eq_compat 
 (fun _ => [1/]1+n.+1 * (sigma (fun j => [1/]1+n * f j)) n.+1 +
      (sigma (fun i=> [1/]1+n.+1 * f i)) n.+1)
 (fun _ => (sigma (fun j => [1/]1+n * f j)) n.+1));last first.
  move=>h hk;apply sigma_dist1. 
 rewrite (prod_S _ n.+1)=>h.
 apply (Umult_le_simpl_left _ _ ((sigma (fun j : nat => [1/]1+n * f j)) n.+1) ).
  move=>h';symmetry in h'. 
  assert ([1/]1+n * f n == 0) as h1.
   apply (sigma_zero_elim _ h');auto. 
  destruct h0;exists n;split;auto. 
  apply (Umult_simpl_left ([1/]1+n));auto;repeat Usimpl;auto. 
 apply h; move=>[k [hk hk']];move:hk';case hk' : (k == n.+1);auto.
 move=>h'. 
 assert ([1/]1+n * f n == 0) as h1.
  apply (sigma_zero_elim _ h');auto. 
 destruct h0;exists n;split;auto. 
  apply (Umult_simpl_left ([1/]1+n));auto;repeat Usimpl;auto. 
 move=>hk'0;destruct h0;exists k;split;auto.
 apply le_lt_eq_dec in hk;destruct hk;auto.
 apply lt_S_n in l;auto.
 apply eq_add_S in e;rewrite e eq_refl in hk'; done. 
Qed.

(** The average of the complements is the complement of the average:
    [sigma ([1/]1+n * [1-] f i) (S n) == [1-] sigma ([1/]1+n * f i) (S n)].
    The proof uses that the [S n] weights [[1/]1+n] sum to 1
    ([Unth_sigma_Sn]), together with [sigma_inv]. *)
Lemma sigma_inv_simpl : forall (n:nat) (f: nat -> U),
    sigma (fun i => [1/]1+n * [1-] (f i)) (S n) == [1-] sigma (fun i => [1/]1+n * (f  i)) (S n).
Proof.
intros.
pose (g:= fun (k:nat) => [1/]1+n).
assert (Rg:retract g (S n)) by auto.
transitivity (sigma (fun i : nat => [1/]1+n * [1-] f i) (S n) + [1-] sigma g (S n)).
cbv delta [g].
rewrite <- Unth_sigma_Sn.
rewrite Uinv_one;auto.
symmetry; apply (sigma_inv f Rg).
Qed.

(** [prod_sigma_average] applied to the complements [[1-] f y] of a
    function on [B], indexed through [enum B]; the average of the
    complements is rewritten with [sigma_inv_simpl]. *)
Lemma prod_sigma_averagefin  : forall (f:B-> U), 
 prod (fun i => [1-] f (nth b (enum B) i)) 
      #|B|                        <=
 prod (fun _=> [1-]  
               (sigma (fun i => [1/]1+ #|B|.-1 * f (nth b (enum B) i)) #|B|))
      #|B|.
Proof.
move=>f;have:(0<#|B|)%nat by case h:(0<#|B|)%nat=>//;case/card_gt0P:h;exists b.
case:#|B|=>//n _.
apply (Ole_trans _ (prod (fun _ : nat => (sigma (fun i : nat => [1/]1+n * 
       [1-] f (nth b (enum B) i))) n.+1) n.+1) _);last first.
 apply prod_le_compat=>k Hk;by rewrite sigma_inv_simpl.
apply prod_sigma_average.
Qed.

(** Pointwise, [NB2U] of [[exists y, P x y]] equals the lift of
    [[forall y, ~~ P x y]]. *)
Lemma forall_exists_fB2U : forall (P: A -> B -> bool), 
(fun x => NB2U [exists y, P x y]) == fB2U (fun x =>[forall y, ~~ P x y]).
Proof.
move=>P x;unfold fB2U,B2U,NB2U.  
rewrite -negb_exists.
case:[exists y, P x y]=>//=.
Qed.   

(** Lifting a negated predicate gives the [finv] of the lifted
    predicate. *)
Lemma finv_fB2U : forall (P: A -> bool), 
 (fB2U (fun y => ~~ P y)) == 
 (Uprop.finv (fB2U (fun y => P y))).
Proof. 
move=>P x;unfold fB2U,B2U,Uprop.finv. 
case:(P x);auto. 
Qed.

(** The lift of [[forall y, ~~ P x y]] equals [prodConj] of the [finv] of
    the lifted predicates [fun s => P s e]. Proof by induction on the
    enumeration of [B]. *)
Lemma forall_prodConj_fB2U : forall (P: A -> B -> bool),
 fB2U (fun x => [forall y, ~~ P x y]) ==
 prodConj (fun e => Uprop.finv (fB2U (fun s => P s e))).
Proof.
move=>P x;unfold prodConj,fB2U,B2U,Uprop.finv. 
rewrite -big_andE.
elim:(index_enum B);first by do 2 rewrite big_nil.
move=>t s H;do 2 rewrite big_cons;case: (P x t).
 rewrite Bool.andb_false_l;by repeat Usimpl.
repeat Usimpl;by rewrite H.
Qed.


End Conditional.


(** *  Alea/Bigop equivalence
 *)

(** Links between ssreflect bigop sums over finite types and Alea sums in
    [Rp]. *)
Section Bigop.

(* [a] is an element of [A], used as the default value of [nth]. *)
Variables (A:finType) (a:A). 

(** Sum of [f y] over all [y] in [A], with ssreflect's bigop and
    [Rpplus]. NOTE(review): despite its name, this is a sum, not a
    product. *)
Definition prodOP (f:A->Rp) :=
  \big[Rpplus/R0]_y f y.

(** The bigop sum [prodOP f] equals [Rpsigma] over the indices of
    [enum A]. Proof by induction on [enum A]. *)
Lemma rpsigma_bigop : forall (f:A ->Rp),
  prodOP f == Rpsigma (fun x => f (nth a (enum A) x)) #|A|.
Proof.
move=>f;unfold prodOP,index_enum. 
rewrite -enumT cardE. 
elim:(enum A)=>//. 
 rewrite big_nil Rpsigma_0;auto.
move=>t s H;rewrite big_cons. 
change (size (t::s)) with (size s).+1. 
rewrite Rpsigma_S_lift H.
apply Rpplus_eq_compat=>//.
Qed.


(** Iterating [Rpplus m] [n] times, starting from [O], gives [n * m].
    Proof by induction on [n]. *)
Lemma iter_Rpplus_0 : forall (n:nat) (m :Rp),
 ssrnat.iter n (Rpplus m) O == Rpmult n m.
Proof.
elim.
 (* Base case; the name [n] introduced here is bound to the [Rp] value. *)
 intro n. rewrite Rpmult_zero_left. done.
intros n hind m. rewrite iterS hind.
assert ((m + n * m)%Rp == (R1 * m + n * m)%Rp).
 auto. 
rewrite H. clear H. rewrite -Rpdistr_plus_right.
apply Rpmult_eq_compat=>//.
auto. 
Qed. 

(** Monotonicity of bigop sums in [Rp]: if [f v <= g v] for every [v],
    the sums over [T] compare the same way. Proof by induction on
    [index_enum T]. This lemma does not use the section variables. *)
Lemma bigRpplusleq : forall (T:finType) (f g:T->Rp),
  (forall v, f v <= g v) ->
  \big[Rpplus/R0]_v (f v) <=  \big[Rpplus/R0]_v (g v).
Proof.
intros. elim:(index_enum T). 
 by repeat rewrite big_nil.
intros. repeat rewrite big_cons. 
apply Rpplus_le_compat=>//.
Qed.

End Bigop.
