(*
 * The LOOP Project
 *
 * The LOOP Team, Dresden University and Nijmegen University
 *
 * Copyright (C) 2002
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License in file COPYING in this or one of the
 * parent directories for more details.
 *
 * Created 14.5.99 by Hendrik
 *
 * Time-stamp: <Monday 24 June 02 17:42:07 tews@ithif51>
 *
 * Variance checking pass for CCSL
 *
 * $Id: variance.ml,v 1.9 2002/07/03 12:01:16 tews Exp $
 *
 *)

open Util
open Global
open Error
open Top_variant_types
open Ccsl_pretty
open Logic_util
open Types_util
;;

(***********************************************************************
 *
 * Errors
 *
 *)

(* Raised by the error functions of this pass (adt_not_strict,
 * variance_missing, variance_mismatch_error) right after they have
 * reported the problem via error_message.  It is not caught in this file.
 *)
exception Variance_error

(* Debug output for this pass: print [s] via print_verbose, but only when
 * the _DEBUG_VARIANCE debug level is enabled.
 *)
let d s =
  let debugging = debug_level _DEBUG_VARIANCE in
  if debugging then print_verbose s

(* Report, at the location of [token], that the carrier of an ADT occurs
 * at a position that is not strictly positive, then raise Variance_error.
 *)
let adt_not_strict token =
  let location = remove_option token.loc in
  error_message location
    "Carrier must occur only at strictly positive positions in an Adt.";
  raise Variance_error

(* Report, at the location of [token], that an explicit variance
 * annotation is required for a type parameter, then raise Variance_error.
 *)
let variance_missing token =
  let location = remove_option token.loc in
  let message =
    "Variance annotation for type parameters is required here." in
  error_message location message;
  raise Variance_error

(* Report that the variance derived for type parameter [id]
 * (id.id_variance) is incompatible with the variance [uvar] the user
 * declared for it, then raise Variance_error.
 * Both variances are printed so that the conflict can be understood
 * without looking up the declaration.  Callers in this file only reach
 * this after variance_match failed, so [uvar] is never Unset here.
 *)
let variance_mismatch_error id uvar =
  error_message (remove_option id.id_token.loc)
    ("Variance mismatch. Derived variance " ^
     (string_of_ccsl_variance id.id_variance) ^
     " for type parameter " ^
     id.id_token.token_name ^
     " does not match the declared variance " ^
     (string_of_ccsl_variance uvar) ^ ".");
  raise Variance_error


(***********************************************************************
 * 
 * manipulate variances
 *)


(* [variance_match v uv] tells whether the derived variance [v] is
 * compatible with the user-declared variance [uv]:
 *  - an absent declaration (Unset) or an identical one always matches,
 *  - two Pair values are compared componentwise (derived <= declared),
 *  - a derived Pair against a simple declaration is first simplified
 *    with make_simple,
 *  - a derived Unused matches everything, Pos and Neg match Mixed.
 * [v] must not be Unset.
 *)
let rec variance_match v uv =
  match v, uv with
    | Unset, _ -> assert false
    | _ when v = uv -> true
    | _, Unset -> true
    | Pair (derived_n, derived_p), Pair (declared_n, declared_p) ->
        derived_n <= declared_n && derived_p <= declared_p
    | Pair _, _ -> variance_match (make_simple v) uv
    | Unused, _ -> true
    | (Pos | Neg), Mixed -> true
    | _ -> false


(* Short local names for two variance operations imported from
 * Logic_util:
 *
 *     val variance_subst : variance_type -> variance_type -> variance_type
 *
 *     val variance_join : variance_type -> variance_type -> variance_type
 *
 * make_simple and valid_variance, also from Logic_util, are used under
 * their own names.
 *)

let vsubst v1 v2 = variance_subst v1 v2

let vjoin v1 v2 = variance_join v1 v2


(* A variance is strictly positive when it is Unused, Pair(-1,-1) or
 * Pair(-1,0).  Unset must not reach this function.
 *)
let is_strictly_positive v =
  match v with
    | Unset -> assert false
    | Unused | Pair (-1, (-1 | 0)) -> true
    | _ -> false

(* Context variance at the root of a type: every traversal in this file
 * (do_action, do_adt_constructor, do_type_def) starts do_type with it.
 * NOTE(review): the Pair encoding is defined in Logic_util; presumably
 * Pair(neg, pos) with -1 meaning "does not occur" -- confirm there.
 *)
let start_var = Pair(-1,0)

(* Variance of the domain position of a function type; do_type combines
 * it with the current context (vsubst cvar fun_var) when it descends into
 * a function domain.
 *)
let fun_var = Pair(1,-1)

(***********************************************************************
 * 
 * checks for cartesian functors
 *)

(* State of the check whether Self/Carrier occurs only where an (extended)
 * cartesian functor allows it.  do_type threads this state down the type
 * structure; an occurrence of Self/Carrier (with a variance other than
 * Unused) in state CheckFailed clears self_info.cartesian_flag.
 *)
type check_cartesian =
  | CheckExtended   (* initial state for class actions (see do_action) *)
  | CheckCartesian  (* occurrences still allowed, e.g. in a function domain *)
  | CheckFailed     (* any further occurrence spoils the cartesian flag *)
  | NoCheck         (* check disabled: ADT constructors, type definitions *)

(* Cartesian-check states for the (domain, codomain) of a function type
 * met in state [cart]: from the extended state the domain becomes
 * CheckCartesian while the codomain stays extended; any other active
 * state fails on both sides; NoCheck is propagated unchanged.
 *)
let check_cart_fun_step cart =
  match cart with
    | NoCheck -> (NoCheck, NoCheck)
    | CheckExtended -> (CheckCartesian, CheckExtended)
    | CheckCartesian | CheckFailed -> (CheckFailed, CheckFailed)

(* Cartesian-check state for an argument of a class type, given the state
 * [cart] at the class type and the declared variance [var] of the
 * matching class parameter.  Only the extended state survives, and only
 * for an Unused or Pos parameter; Neg/Mixed degrades it to CheckCartesian.
 *)
let check_cart_class_step cart var =
  let simple = make_simple var in
  match cart with
    | NoCheck -> NoCheck
    | CheckCartesian | CheckFailed -> CheckFailed
    | CheckExtended ->
        (match simple with
           | Unused | Pos -> CheckExtended
           | Neg | Mixed -> CheckCartesian
           | _ -> assert false)

(* Cartesian-check state for an argument of an ADT type, given the state
 * [cart] at the ADT type and the declared variance [var] of the matching
 * ADT parameter.  An Unused or Pos parameter keeps the current state;
 * Neg/Mixed moves one step down (CheckExtended -> CheckCartesian ->
 * CheckFailed).
 *)
let check_cart_adt_step cart var =
  let simple = make_simple var in
  match cart with
    | NoCheck -> NoCheck
    | CheckFailed -> CheckFailed
    | CheckExtended | CheckCartesian ->
        (match simple with
           | Unused | Pos -> cart
           | Neg | Mixed ->
               if cart = CheckExtended then CheckCartesian else CheckFailed
           | _ -> assert false)

(* Below a ground type constructor the cartesian check always fails,
 * unless checking is switched off.
 *)
let check_cart_ground_step cart =
  if cart = NoCheck then NoCheck else CheckFailed


(* Classify the functor of a class or ADT from the variance [self_var] of
 * Self/Carrier, the cartesian flag [cart_flag] and the variances
 * [param_vars] of its type parameters:
 *  - Unused or Pair(-1,-1): (strictly) constant,
 *  - Pair(-1,0): (strictly) polynomial,
 *    where "strictly" means all parameters are strictly positive,
 *  - Pair(1,n) with n <= 0: extended cartesian if [cart_flag] holds,
 *    extended polynomial otherwise,
 *  - anything else: higher-order.
 *)
let functor_of_variance self_var cart_flag param_vars =
  let params_strict () = List.for_all is_strictly_positive param_vars in
  match self_var with
    | Unused | Pair (-1, -1) ->
        if params_strict () then StrictlyConstantFunctor
        else ConstantFunctor
    | Pair (-1, 0) ->
        if params_strict () then StrictlyPolynomialFunctor
        else PolynomialFunctor
    | Pair (1, n) when n <= 0 ->
        if cart_flag then ExtendedCartesianFunctor
        else ExtendedPolynomialFunctor
    | _ -> HigherOrderFunctor


(* Human-readable name of a functor classification (used in the debug
 * messages of do_class and do_adt).
 *)
let string_of_functor ftype =
  match ftype with
    | UnknownFunctor -> "unknown"
    | StrictlyConstantFunctor -> "strictly constant"
    | ConstantFunctor -> "constant"
    | StrictlyPolynomialFunctor -> "strictly polynomial"
    | PolynomialFunctor -> "polynomial"
    | ExtendedCartesianFunctor -> "extended cartesian"
    | ExtendedPolynomialFunctor -> "extended polynomial"
    | HigherOrderFunctor -> "higher-order polynomial"


(***********************************************************************
 ***********************************************************************
 *
 * variance check over types
 *
 *)

(* Information about Self/Carrier collected by do_type.
 *   self_var       -- join (vjoin) of the variances of all Self/Carrier
 *                     occurrences met so far
 *   cartesian_flag -- cleared once Self/Carrier occurs with a variance
 *                     other than Unused in cartesian state CheckFailed
 *)
type self_info = {
  mutable self_var : variance_type;
  mutable cartesian_flag : bool
}


  (* do_type mixedself sinfo cvar ccart typ
   *
   * Walk the CCSL type [typ], which occurs in a context of variance
   * [cvar] and cartesian-check state [ccart], and record
   *  - for every bound type variable: its variance, joined into
   *    id.id_variance,
   *  - for Self/Carrier: its variance, joined into sinfo.self_var, and
   *    whether it still fits a cartesian functor (sinfo.cartesian_flag),
   *  - in [mixedself]: whether Self/Carrier occurs at a Mixed position
   *    (this decides whether a class needs MixedSelfInstFeature).
   * Arguments of classes, ADTs and ground types are visited in the
   * context vsubst (declared variance of the parameter) [cvar].
   * Types that never appear in CCSL input raise an assertion failure.
   *)
let rec do_type (mixedself : bool ref) sinfo cvar ccart typ =
  (* Visit the actual arguments [args] of a type constructor whose formal
   * parameters are [params]; [cart_step] computes the cartesian state of
   * an argument from the declared variance of its parameter. *)
  let do_arguments params args cart_step =
    List.iter2
      (fun param arg ->
         match param, arg with
           | TypeParameter param_id, TypeArgument arg_typ ->
               do_type mixedself sinfo (vsubst param_id.id_variance cvar)
                 (cart_step param_id.id_variance)
                 arg_typ)
      params args
  in
  match typ with
    | BoundTypeVariable id ->
        id.id_variance <- vjoin id.id_variance cvar
    | Self
    | Carrier ->
        sinfo.self_var <- vjoin sinfo.self_var cvar;
        let simple = make_simple cvar in
        (* an occurrence at an unused position never spoils the flag *)
        if ccart <> NoCheck && simple <> Unused then
          sinfo.cartesian_flag <-
            (sinfo.cartesian_flag && ccart <> CheckFailed);
        if simple = Mixed then mixedself := true
    | Bool -> ()
    | Function (dom, codom) ->
        let domcart, codomcart = check_cart_fun_step ccart in
        do_type mixedself sinfo (vsubst cvar fun_var) domcart dom;
        do_type mixedself sinfo cvar codomcart codom
    | Product tl ->
        List.iter (do_type mixedself sinfo cvar ccart) tl
    | Class (cl, args) ->
        do_arguments cl#get_parameters args (check_cart_class_step ccart)
    | Adt (adt, _, args) ->
        do_arguments adt#get_parameters args (check_cart_adt_step ccart)
    | Groundtype (id, args) ->
        do_arguments (get_ground_type_parameters id) args
          (fun _ -> check_cart_ground_step ccart)
    (* not in ccsl input types *)
    | Record _
    | TypeConstant _
    | IFace _
    | FreeTypeVariable _
    | Array _
    | Predtype _
    | SmartFunction _
      -> assert false



(***********************************************************************
 ***********************************************************************
 *
 * top level units
 *
 *)

let do_action (mixedself : bool ref) selfinfo m =
  (* Check one signature action [m]: its curried type is walked from the
   * root context start_var with full (CheckExtended) cartesian checking.
   * [mixedself] is set when Self occurs at a mixed position; do_class
   * uses it to enable MixedSelfInstFeature.
   *)
  let action_type = m#get_curried_type in
  do_type mixedself selfinfo start_var CheckExtended action_type


let do_adt_constructor selfinfo c =
  (* Check the domain of ADT constructor [c], without cartesian checking.
   * The mixed-self flag only matters for classes, so it is collected in
   * a throw-away cell here; a Carrier at a mixed position is rejected
   * later anyway, because do_adt requires the Carrier variance to be
   * strictly positive.
   *)
  let unused_mixedself = ref false in
  do_type unused_mixedself selfinfo start_var NoCheck c#get_domain


(* Variance pass for one class [ccl]:
 *  1. save the user-declared variances of the class type parameters and
 *     reset each of them to Pair(-1,-1),
 *  2. derive the variances of the parameters and of Self from all
 *     signature actions, enabling MixedSelfInstFeature if Self occurs at
 *     a mixed position,
 *  3. store Self's variance and the functor classification in [ccl],
 *  4. in pedantic mode, report everything that is not a strictly
 *     covariant polynomial functor with final semantics,
 *  5. compare the derived variances with the user declarations: a
 *     matching declaration replaces the derived variance, a mismatch
 *     raises Variance_error (via variance_mismatch_error).
 *)
let do_class ccl =
  let parameter_ids =
    List.map (function TypeParameter id -> id) ccl#get_parameters in
  (* save user settings; derivation starts from Pair(-1,-1) *)
  let user_vars =
    List.map
      (fun id ->
         let declared = id.id_variance in
         id.id_variance <- Pair (-1, -1);
         declared)
      parameter_ids in
  let selfinfo = { self_var = Pair (-1, -1); cartesian_flag = true } in
  let mixedselfinst = ref false in
  (* get information *)
  List.iter (do_action mixedselfinst selfinfo) ccl#get_all_sig_actions;
  if !mixedselfinst then ccl#put_feature MixedSelfInstFeature;
  (* record information *)
  let functor_type =
    functor_of_variance selfinfo.self_var selfinfo.cartesian_flag
      (List.map (fun id -> id.id_variance) parameter_ids) in
  (* save status *)
  ccl#set_self_variance selfinfo.self_var;
  ccl#set_functor_type functor_type;
  (* check strictness *)
  if !pedantic_mode then begin
    List.iter
      (fun tparam_id ->
         if not (is_strictly_positive tparam_id.id_variance) then
           pedantic_error (remove_option tparam_id.id_token.loc)
             ("This type parameter has variance " ^
              (string_of_ccsl_variance tparam_id.id_variance) ^
              "\n" ^
              "Only strictly covariant polynomial functors allowed"))
      parameter_ids;
    (match functor_type with
       | StrictlyConstantFunctor
       | StrictlyPolynomialFunctor -> ()
       | ConstantFunctor
       | PolynomialFunctor
       | ExtendedCartesianFunctor
       | ExtendedPolynomialFunctor
       | HigherOrderFunctor ->
           pedantic_error (remove_option (ccl#get_token).loc)
             "Only strictly covariant polynomial functors allowed"
       | UnknownFunctor -> assert false);
    if not (ccl#has_feature FinalSemanticsFeature) then
      pedantic_error (remove_option (ccl#get_token).loc)
        "Loose semantics not allowed"
  end;
  (* print a message *)
  d (" ** Class " ^ ccl#get_name ^
     " (" ^
     (string_of_functor functor_type) ^
     ") : " ^
     (string_of_ccsl_variance ccl#get_self_variance) ^ " " ^
     (if selfinfo.cartesian_flag then "t" else "f") ^ " :: " ^
     (List.fold_left
        (fun accu id ->
           accu ^ ", " ^
           id.id_token.token_name ^
           " : " ^ (string_of_ccsl_variance id.id_variance))
        "" parameter_ids));
  (* check consistency *)
  assert (valid_variance selfinfo.self_var &&
          List.for_all (fun id -> valid_variance id.id_variance)
            parameter_ids);
  (* compare with user settings *)
  List.iter2
    (fun id uvar ->
       if not (variance_match id.id_variance uvar) then
         variance_mismatch_error id uvar
       else if uvar <> Unset then
         id.id_variance <- uvar)
    parameter_ids user_vars



(* Variance pass for one abstract data type [adt]:
 *  1. save the user-declared variances of the ADT type parameters and
 *     reset each of them to Pair(-1,-1),
 *  2. derive the variances of the parameters and of the Carrier from the
 *     domains of all constructors,
 *  3. store the Carrier variance and the functor classification in [adt],
 *  4. reject the ADT (adt_not_strict raises Variance_error) unless the
 *     Carrier occurs only at strictly positive positions,
 *  5. in pedantic mode, report type parameters that are not strictly
 *     positive,
 *  6. compare the derived variances with the user declarations: a
 *     matching declaration replaces the derived variance, a mismatch
 *     raises Variance_error (via variance_mismatch_error).
 *)
let do_adt adt =
  let parameter_ids =
    List.map (function TypeParameter id -> id) adt#get_parameters in
  (* save user settings; derivation starts from Pair(-1,-1) *)
  let user_vars =
    List.map
      (fun id ->
         let declared = id.id_variance in
         id.id_variance <- Pair (-1, -1);
         declared)
      parameter_ids in
  let selfinfo = { self_var = Pair (-1, -1); cartesian_flag = true } in
  (* get information *)
  List.iter (do_adt_constructor selfinfo) adt#get_adt_constructors;
  (* record information *)
  let functor_type =
    functor_of_variance selfinfo.self_var selfinfo.cartesian_flag
      (List.map (fun id -> id.id_variance) parameter_ids) in
  (* save status *)
  adt#set_self_variance selfinfo.self_var;
  adt#set_functor_type functor_type;
  (* adt check *)
  if not (is_strictly_positive selfinfo.self_var) then
    adt_not_strict adt#get_token;
  (* check pedanticness *)
  if !pedantic_mode then
    List.iter
      (fun tparam_id ->
         if not (is_strictly_positive tparam_id.id_variance) then
           pedantic_error (remove_option tparam_id.id_token.loc)
             ("This type parameter has variance " ^
              (string_of_ccsl_variance tparam_id.id_variance) ^
              "\n" ^
              "Only strictly covariant polynomial functors allowed"))
      parameter_ids;
  (* print a message *)
  d (" ** Adt " ^ adt#get_name ^
     " (" ^
     (string_of_functor functor_type) ^
     ") : " ^
     (string_of_ccsl_variance adt#get_self_variance) ^ " " ^
     (if selfinfo.cartesian_flag then "t" else "f") ^ " :: " ^
     (List.fold_left
        (fun accu id ->
           accu ^ ", " ^
           id.id_token.token_name ^
           " : " ^ (string_of_ccsl_variance id.id_variance))
        "" parameter_ids));
  (* check consistency *)
  assert (valid_variance selfinfo.self_var &&
          List.for_all (fun id -> valid_variance id.id_variance)
            parameter_ids);
  (* compare with user settings *)
  List.iter2
    (fun id uvar ->
       if not (variance_match id.id_variance uvar) then
         variance_mismatch_error id uvar
       else if uvar <> Unset then
         id.id_variance <- uvar)
    parameter_ids user_vars


(* Variance pass for one type definition [tdef] of a ground signature.
 * The local type parameters of [tdef] are reset to Pair(-1,-1), the
 * defining type is walked with do_type (which joins variances into every
 * bound type variable met, possibly including the global signature
 * parameters), and the derived local variances are then compared with
 * the user declarations: a matching declaration replaces the derived
 * variance, a mismatch raises Variance_error.
 * [global_parameters] is only used for the debug message.
 *)
let do_type_def global_parameters tdef =
  let local_parameter_ids =
    List.map (function TypeParameter id -> id) tdef.id_parameters in
  (* save user settings; derivation starts from Pair(-1,-1) *)
  let local_user_vars =
    List.map
      (fun id ->
         let declared = id.id_variance in
         id.id_variance <- Pair (-1, -1);
         declared)
      local_parameter_ids in
  (* The self information is intentionally invalid (Unset): the traversal
   * is not expected to meet Self/Carrier here.  The mixed-self flag is
   * irrelevant for type definitions and is discarded. *)
  let selfinfo = { self_var = Unset; cartesian_flag = true } in
  let unused_mixedself = ref false in
  (* get information *)
  do_type unused_mixedself selfinfo start_var NoCheck tdef.id_type;
  (* print a message *)
  d (" ** Typedef " ^ tdef.id_token.token_name ^ " : " ^
     (List.fold_left
        (fun accu id ->
           accu ^ ", " ^
           id.id_token.token_name ^
           " : " ^ (string_of_ccsl_variance id.id_variance))
        "" (local_parameter_ids @ global_parameters)));
  (* check consistency *)
  assert (List.for_all (fun id -> valid_variance id.id_variance)
            local_parameter_ids);
  (* compare with user settings *)
  List.iter2
    (fun id uvar ->
       if not (variance_match id.id_variance uvar) then
         variance_mismatch_error id uvar
       else if uvar <> Unset then
         id.id_variance <- uvar)
    local_parameter_ids local_user_vars


(* Variance pass for one ground signature [si]:
 *  - if the signature contains type declarations (ground types that are
 *    not type definitions), every global type parameter must carry a
 *    variance annotation; the local parameters of every type declaration
 *    must be annotated in any case (otherwise variance_missing raises
 *    Variance_error),
 *  - in pedantic mode, only plain ground signatures (no parameters) are
 *    accepted,
 *  - all type definitions are checked with do_type_def, deriving the
 *    variances of the global parameters, which are then compared with
 *    the user declarations (a mismatch raises Variance_error),
 *  - finally, member-local type parameters still Unset become Unused.
 *)
let do_sig si =
  d (" ** GroundSig " ^ si#get_name ^ " :: " ^
     (List.fold_left
        (fun accu tp ->
           match tp with
             | TypeParameter id ->
                 accu ^ ", " ^
                 id.id_token.token_name ^
                 " : " ^ (string_of_ccsl_variance id.id_variance))
        "" si#get_parameters));
  let tdefs, tdecl = List.partition is_type_def si#get_all_ground_types in
  let global_param_ids =
    List.map (function TypeParameter id -> id) si#get_parameters in
  (* require global variance annotations if type declarations are present *)
  if tdecl <> [] then
    List.iter
      (fun id ->
         if id.id_variance = Unset then variance_missing id.id_token)
      global_param_ids;
  (* require local variance annotations in all type declarations *)
  List.iter
    (fun tdecl_id ->
       List.iter
         (function TypeParameter id ->
            if id.id_variance = Unset then variance_missing id.id_token)
         tdecl_id.id_parameters)
    tdecl;
  (* save user settings of the global parameters; derivation starts
   * from Pair(-1,-1) *)
  let global_user_vars =
    List.map
      (fun id ->
         let declared = id.id_variance in
         id.id_variance <- Pair (-1, -1);
         declared)
      global_param_ids in
  (* check pedanticness *)
  if !pedantic_mode then begin
    if global_param_ids <> [] then      (* non plain danger *)
      (match si#get_all_ground_types with
         | id :: _ ->
             pedantic_error (remove_option id.id_token.loc)
               "Only plain ground signature allowed"
         | [] -> ());
    List.iter
      (fun tcon_id ->
         if tcon_id.id_parameters <> [] then
           pedantic_error (remove_option tcon_id.id_token.loc)
             "Only plain ground signature allowed")
      si#get_all_ground_types
  end;
  (* get information *)
  List.iter (do_type_def global_param_ids) tdefs;
  (* print a message *)
  d (" ** GroundSig " ^ si#get_name ^ " : " ^
     (List.fold_left
        (fun accu id ->
           accu ^ ", " ^
           id.id_token.token_name ^
           " : " ^ (string_of_ccsl_variance id.id_variance))
        "" global_param_ids));
  (* check consistency *)
  assert (List.for_all (fun id -> valid_variance id.id_variance)
            global_param_ids);
  (* compare with user settings *)
  List.iter2
    (fun id uvar ->
       if not (variance_match id.id_variance uvar) then
         variance_mismatch_error id uvar
       else if uvar <> Unset then
         id.id_variance <- uvar)
    global_param_ids global_user_vars;
  (* set local variance in definitions *)
  List.iter
    (fun m ->
       List.iter
         (function TypeParameter id ->
            if id.id_variance = Unset then id.id_variance <- Unused)
         m#get_local_parameters)
    si#get_members


(* Run the variance check appropriate for one top-level CCSL declaration. *)
let variance_ast decl =
  match decl with
    | CCSL_class_dec cl -> do_class cl
    | CCSL_adt_dec adt -> do_adt adt
    | CCSL_sig_dec si -> do_sig si


(* Entry point of the pass: check every top-level declaration of [ast] in
 * order.  A Variance_error raised for one declaration is not caught here
 * and aborts the remaining checks.
 *)
let ccsl_variance_pass (ast: Classtypes.ccsl_ast list) =
  List.iter (fun decl -> variance_ast decl) ast



(*** Local Variables: ***)
(*** version-control: t ***)
(*** kept-new-versions: 5 ***)
(*** delete-old-versions: t ***)
(*** time-stamp-line-limit: 30 ***)
(*** End: ***)

