(*
 * The LOOP Project
 *
 * The LOOP Team, Dresden University and Nijmegen University
 *
 * Copyright (C) 2002
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
 * General Public License in file COPYING in this or one of the
 * parent directories for more details.
 *
 * adopted to Revision on 2.1.99 by Hendrik
 *
 * Time-stamp: <Thursday 16 May 02 15:41:20 tews@ithif51>
 *
 * Utility functions for ccsl parser
 *
 * $Id: table.ml,v 1.8 2002/05/22 13:42:46 tews Exp $
 *
 *)

open Util;;

(***********************************************************************
 *
 * Module specification
 *)


(******************************************************************
 * Defining types
 *)


       (* This symbol table is polymorphic in three types:
        * 'space : an enumeration of the different namespaces
        * 'shape : a type information used for overloading
        * 'symbol: the internal representation this table manages
        *)

       (* A single block (scope): maps a (namespace, name) key to a
        * mutable list of (shape, symbol) alternatives.  The list holds
        * more than one element when the name is overloaded.
        *)
type ('space, 'shape, 'symbol) block_type = 
    ( ('space*string), ('shape*'symbol) list ref) Hashtbl.t 

       (* A block as handed out to clients, paired with the equality
        * predicate on shapes that is used when searching it.
        *)
and  ('space, 'shape, 'symbol) local =
     ('space, 'shape, 'symbol) block_type 
       *
     ('shape -> 'shape -> bool)


       (* The global symbol table: an array of blocks used as a stack
        * of nested scopes.
        *)
and  ('space, 'shape, 'symbol) global = {
  (* number of slots in table; set by new_table, doubled by
   * enlarge_table *)
  mutable size              : int;
  (* index of the innermost open block, -1 if no block is open *)
  mutable top               : int;
  (* index of the block that create/overload write into, -1 if none *)
  mutable add_to            : int;
  (* earlier values of add_to, pushed by start_block/open_block and
   * popped by end_of_defs/close_block *)
          write_stack       : int Stack.t;
  (* the blocks themselves; slots 0..top are the open blocks *)
  mutable table : ('space, 'shape, 'symbol) block_type array;
  (* equality predicate on shapes, supplied to new_table *)
          eq                : 'shape -> 'shape -> bool
}

(******************************************************************
 *
 * Utility section
 *)

(* Allocate a fresh, empty hash table with the initial size used for
 * every block of the symbol table. *)
let create_hash () =
  let initial_size = 53 in
    Hashtbl.create initial_size;;

    (* same as @ but reverses the first argument, equivalent to 
     * (rev l1) @ l2 but faster (tail recursive, a single pass
     * over l1).
     *
     * This is exactly List.rev_append from the standard library; the
     * name is kept because other functions in this file call it.
     *)
let rev_concat l1 l2 = List.rev_append l1 l2


      (* Debugging switch: if true, start_block and close_block print
       * a trace line with the block number on stderr.
       *)
let debug_blocks = false


(******************************************************************
 *
 * Create a new symboltable
 *
 * The argument is supposed to be an equality predicate on the type
 * 'shape; it is stored in the table and used to resolve overloading.
 *
 * The new table has room for 15 blocks, every slot holding its own
 * fresh, empty hash table, and no block is open yet
 * (top = add_to = -1).
 *
 * val new_table : ('shape -> 'shape -> bool)
 *                     -> ('space, 'shape, 'symbol) global
 *)

let new_table eq =
  let default_size = 15 in
    { size = default_size;
      top  = -1;
      add_to = -1;
      write_stack = Stack.create();
      (* Array.init calls create_hash once per slot, so no two slots
       * share a table.  It replaces the deprecated Array.create, which
       * filled every slot with the same table and needed a patch-up
       * loop afterwards. *)
      table = Array.init default_size (fun _ -> create_hash ());
      eq = eq }

(* True iff at least one block is open (top >= 0) and a block is
 * selected for writing (add_to >= 0). *)
let table_non_empty st = not (st.top < 0 || st.add_to < 0)

(* Create a fresh, empty local block that is not part of the global
 * table st; it only borrows the shape equality of st. *)
let new_local st =
  let empty_block = create_hash () in
    (empty_block, st.eq)


(**************************
 *
 * double the size of the symboltable
 *
 * The existing blocks keep their positions (and their identity);
 * every additional slot receives its own fresh, empty hash table.
 *)

let enlarge_table st =
  let old_size = st.size in
  let new_size = old_size * 2 in
  (* Array.init replaces the deprecated Array.create plus blit plus
   * fill loop: slots below old_size are copied, all others get a
   * table of their own. *)
  let ntable =
    Array.init new_size
      (fun i -> if i < old_size then st.table.(i) else create_hash ())
  in
    st.size <- new_size;
    st.table <- ntable


(******************************************************************
 * The Interface 
 *)

       (* Raised by the lookup and delete functions of this module
        * when no entry for the requested name (and shape, where one
        * is given) exists in the table.
        *)
exception Not_defined

    (* exception to be raised if a symbol is overloaded,
       but see documentation for find.
       NOTE(review): nothing in this file raises Overloaded; find
       simply returns the first entry.  Presumably kept for clients
       of this module -- confirm before removing.
     *)
exception Overloaded


(*******************************************************************
 *
 * Creating entries in the Table
 *)
    (* (create st space name shape symbol) makes a new entry for
     *  name in the block that is currently written to.  An entry
     *  made earlier for the same (space, name) in that block is
     *  hidden by the new one, not extended.
     *
     * val create : ('space, 'shape, 'symbol) global ->
     *   'space -> string -> 'shape -> 'symbol -> unit
     *)
let create st space name shape symbol = 
  assert (table_non_empty st);
  let target = st.table.(st.add_to) in
  let alternatives = ref [(shape, symbol)] in
    Hashtbl.add target (space, name) alternatives

    (* add (shape, symbol) as a further alternative for name in the
     *  given local block; the new alternative is put in front of the
     *  existing ones.  If the block has no entry for (space, name)
     *  yet, one is created.
     *  
     * val local_overload : ('space, 'shape, 'symbol) local -> 
     *   'space -> string -> 'shape -> 'symbol -> unit
     *)
let local_overload block space name shape symbol =
  let table = fst block in
  let key = (space, name) in
  let alternative = (shape, symbol) in
    try
      let cell = Hashtbl.find table key in
        cell := alternative :: !cell
    with Not_found -> Hashtbl.add table key (ref [alternative])


    (* like local_overload, but working on the block of the global
     *  table that is currently written to
     *  
     * val overload : ('space, 'shape, 'symbol) global ->
     *   'space -> string -> 'shape -> 'symbol -> unit
     *)
let overload st space name shape symbol =
  assert (table_non_empty st);
  let current = (st.table.(st.add_to), st.eq) in
    local_overload current space name shape symbol


    (* look up name in a local table and return the first symbol
     * stored for it, ignoring shapes
     *
     * raises Not_defined if there is no entry or the entry is empty
     *
     * val find_local : ('space, 'shape, 'symbol) local -> 
     *   'space -> string -> 'symbol
     *)
let find_local block space name =
  let alternatives =
    try !(Hashtbl.find (fst block) (space, name))
    with Not_found -> raise Not_defined
  in
    match alternatives with
      | (_, symbol) :: _ -> symbol
      | [] -> raise Not_defined


    (* look in a local table for the first alternative of name whose
     * shape equals the given one; the predicate of the block is
     * applied as (stored shape, requested shape)
     *
     * raises Not_defined if no alternative matches
     *
     * val find_local_overloaded : ('space, 'shape, 'symbol) local ->
     *   'space -> string -> 'shape -> 'symbol
     *)
let find_local_overloaded block space name shape =
  let (table, same_shape) = block in
    try
      let alternatives = !(Hashtbl.find table (space, name)) in
        snd (List.find (fun (s, _) -> same_shape s shape) alternatives)
    with Not_found -> raise Not_defined


    (* return all (shape, symbol) alternatives stored for name in a
     * local table
     *
     * raises Not_defined if there is no entry or the entry is empty
     *
     * val find_all_local : ('space, 'shape, 'symbol) local -> 
     *   'space -> string -> ('shape * 'symbol) list
     *)
let find_all_local block space name =
  let alternatives =
    try !(Hashtbl.find (fst block) (space, name))
    with Not_found -> raise Not_defined
  in
    match alternatives with
      | [] -> raise Not_defined
      | _ :: _ -> alternatives


    (* collect the alternatives of name from all open blocks; those
     * of the innermost block come first, those of block 0 last
     *
     * raises Not_defined if no open block has an alternative
     *
     * val find_all : ('space, 'shape, 'symbol) global ->
     *   'space -> string -> ('shape*'symbol) list
     *) 
let find_all st space name =
  assert (table_non_empty st);
  let key = (space, name) in
  let collected = ref [] in
    (* walk outwards-in and prepend, so that inner blocks end up in
     * front *)
    for level = 0 to st.top do
      (try collected := !(Hashtbl.find st.table.(level) key) @ !collected
       with Not_found -> ())
    done;
    match !collected with
      | [] -> raise Not_defined
      | all -> all



    (* look for a symbol without regarding shapes: search the open
     * blocks from the innermost one outwards and return the first
     * symbol of the first non-empty entry for name
     *
     * raises Not_defined if no block has such an entry
     *
     * val find : ('space, 'shape, 'symbol) global ->
     *   'space -> string -> 'symbol
     *)
let find st space name = 
  assert (table_non_empty st);
  let key = (space, name) in
  let rec search level =
    if level = -1 then raise Not_defined
    else
      let alternatives =
        try !(Hashtbl.find st.table.(level) key) with Not_found -> []
      in
        match alternatives with
          | (_, symbol) :: _ -> symbol
          | [] -> search (level - 1)
  in
    search st.top


    (* like find, but return a pair: the symbol and the block it was
     * found in (as a local table)
     *
     * raises Not_defined if no block has a non-empty entry for name
     *)
let find_with_block st space name =
  assert (table_non_empty st);
  let key = (space, name) in
  let rec search level =
    if level = -1 then raise Not_defined
    else
      let alternatives =
        try !(Hashtbl.find st.table.(level) key) with Not_found -> []
      in
        match alternatives with
          | (_, symbol) :: _ -> (symbol, (st.table.(level), st.eq))
          | [] -> search (level - 1)
  in
    search st.top


    (* like find_all, but every result is a triple of shape, symbol
     * and the block (as a local table) the alternative was found in;
     * alternatives of the innermost block come first
     *
     * raises Not_defined if no open block has an alternative
     *) 
let find_all_with_block st space name =
  assert (table_non_empty st);
  let key = (space, name) in
  let tagged level =
    try
      List.map
        (fun (sh, sym) -> (sh, sym, (st.table.(level), st.eq)))
        !(Hashtbl.find st.table.(level) key)
    with Not_found -> []
  in
  (* go from block 0 upwards and prepend, so inner blocks end up in
   * front *)
  let rec gather acc level =
    if level > st.top then acc
    else gather (tagged level @ acc) (level + 1)
  in
    match gather [] 0 with
      | [] -> raise Not_defined
      | all -> all



    (* find a symbol by name and shape: search the open blocks from
     * the innermost one outwards and return the first alternative
     * whose shape equals the requested one; st.eq is applied as
     * (requested shape, stored shape)
     *
     * raises Not_defined if no alternative matches
     *
     * val find_overloaded : ('space, 'shape, 'symbol) global ->
     *   'space -> string -> 'shape -> 'symbol
     *)
let find_overloaded (st:('sp,'sh,'sy)global) (space:'sp) name (shape:'sh) =
  assert (table_non_empty st);
  let key = (space, name) in
  let rec search level =
    if level = -1 then raise Not_defined
    else
      try
        let alternatives = !(Hashtbl.find st.table.(level) key) in
          snd (List.find (fun (s, _) -> st.eq shape s) alternatives)
      with Not_found -> search (level - 1)
  in
    (search st.top : 'sy)
      

    (* remove from a local table the first alternative of name whose
     * shape equals the given one; the alternatives in front of it
     * keep their order
     *
     * raises Not_defined if no alternative matches
     *
     * val delete_local_overloaded : ('space, 'shape, 'symbol) local ->
     *   'space -> string -> 'shape -> unit
     *)
let delete_local_overloaded block space name shape =
  let (table, same_shape) = block in
  (* kept holds the alternatives already passed, in reverse order *)
  let rec drop_first kept = function
    | ((t, _) as entry) :: rest ->
        if same_shape t shape then rev_concat kept rest
        else drop_first (entry :: kept) rest
    | [] -> raise Not_defined
  in
    try
      let cell = Hashtbl.find table (space, name) in
        cell := drop_first [] !cell
    with Not_found -> raise Not_defined


    (* apply f to every alternative of every entry in all open
     * blocks, starting with the innermost block
     *
     * val iter : ('space,'shape,'symbol) global -> 
     *   ('space -> string -> 'shape -> 'symbol -> unit) -> unit
     *)
let iter st f = 
  let visit_entry (space, name) sym_list =
    List.iter (fun (shape, symbol) -> f space name shape symbol) !sym_list
  in
    for level = st.top downto 0 do
      Hashtbl.iter visit_entry st.table.(level)
    done

    (* open a new innermost block and direct all further entries
     * into it; the block written to before is remembered on the
     * write stack.  The table is enlarged if it is full.
     *
     * val start_block : ('space, 'shape, 'symbol) global -> 
     *   ('space, 'shape, 'symbol) local
     *)
let start_block st =
  Stack.push st.add_to st.write_stack;
  let level = st.top + 1 in
    st.top <- level;
    st.add_to <- level;
    if debug_blocks then
      prerr_endline ("Start Block " ^ (string_of_int level));
    if level = st.size then enlarge_table st;
    (st.table.(level), st.eq)
  

    (* end the definition section of the current block
     * make entries into the same block 
     * like before start_block was called
     *
     * The current block stays open for lookup; only the write target
     * is reset to the value saved on the write stack.  The block
     * that has been written to is returned.
     *
     * val end_of_defs : ('space, 'shape, 'symbol) global -> 
     *   ('space, 'shape, 'symbol) local
     *)
let end_of_defs st = 
  assert (st.add_to >= 0);
  (* Precondition: entries currently go into the innermost block.
   * The check used to be (add_to <> top), which can never hold on a
   * first call, because start_block, open_block and close_block all
   * leave add_to = top. *)
  assert (st.add_to = st.top);
  let defs = st.table.(st.add_to) in
    st.add_to <- Stack.pop st.write_stack;
    (defs, st.eq)

      
    (* close the current (innermost) block and return it
     *
     * The slot of the closed block is refilled with a fresh, empty
     * table so that a later start_block can reuse it.  If the closed
     * block was also the write target, the previous write target is
     * restored from the write stack.
     *
     * val close_block : ('space, 'shape, 'symbol) global -> 
     *   ('space, 'shape, 'symbol) local
     *)
let close_block st =
  assert (st.top >= 0);
  if debug_blocks then
    prerr_endline ("Close Block " ^ (string_of_int st.top));
  (* Take the block that is really being closed.  Indexing with
   * st.add_to, as was done before, yields the same block as long as
   * add_to = top, but after end_of_defs it returned the enclosing,
   * still open block and was out of bounds for add_to = -1. *)
  let closed = st.table.(st.top) in
    st.table.(st.top) <- create_hash();
    if st.top = st.add_to then
      st.add_to <- Stack.pop st.write_stack;
    st.top <- st.top -1;
    (closed, st.eq)


    (* Close blocks in case of errors during computation, until only
     * the keep_blocks outermost ones are left open.  Nothing happens
     * if no more than keep_blocks blocks are open.
     *
     * val reset_gst : ('space, 'shape, 'symbol) global -> int -> unit
     *)
let reset_gst st keep_blocks =
  assert (keep_blocks >= 0);
  (* top + 1 blocks are open, so this many have to go *)
  let surplus = st.top - keep_blocks + 1 in
    for _i = 1 to surplus do
      ignore (close_block st)
    done
  

    (* nesting scopes: push an existing local name space as new
     * innermost block onto the global symboltable and make it the
     * write target; use close_block to remove it again.  The table
     * is enlarged if it is full.
     *
     * val open_block : ('space, 'shape, 'symbol) global -> 
     *                 ('space, 'shape, 'symbol) local -> unit
     *)
let open_block st lt =
  let (block, _) = lt in
    Stack.push st.add_to st.write_stack;
    let level = st.top + 1 in
      st.top <- level;
      st.add_to <- level;
      if level = st.size then enlarge_table st;
      st.table.(level) <- block
  
    (* return the number of open blocks (0 for a table in which no
     * block has been opened)
     * 
     * val nesting_size : ('space, 'shape, 'symbol) global -> int
     *)
let nesting_size st = succ st.top

    (* render the whole table as a string: two header lines with
     * the number of open blocks, the capacity and the current write
     * target, followed by one line per open block (innermost first)
     * listing its symbols as printed by f
     *
     * val dump  : ('space, 'shape, 'symbol) global ->
     *   ('symbol -> string) -> string
     *)
let dump st f = 
  let buf = Buffer.create 100 in
  let emit = Buffer.add_string buf in
  let emit_int i = emit (string_of_int i) in
  let emit_symbols _ sym_list =
    List.iter (fun (_, symbol) -> emit (f symbol); emit "; ") !sym_list
  in
    emit "Currently ";
    emit_int (st.top + 1);
    emit " blocks out of ";
    emit_int (Array.length st.table);
    emit "\n";
    emit "Writing into ";
    emit_int st.add_to;
    emit "\n";
    for level = st.top downto 0 do
      emit "Contents ";
      emit_int level;
      emit ":";
      emit "{ ";
      Hashtbl.iter emit_symbols st.table.(level);
      emit " }\n"
    done;
    Buffer.contents buf


(* Write the text produced by dump for st and f to the channel c. *)
let print st f c =
  let text = dump st f in
    output_string c text


       (* render a local table as a string: every symbol is printed
        * with f and followed by " | ", the whole is enclosed in
        * square brackets; shapes are not shown
        *
        * val dump_local : ('space, 'shape, 'symbol) local ->
        *      ('symbol -> string) -> string
        *)
let dump_local block f =
  let buf = Buffer.create 20 in
  let emit = Buffer.add_string buf in
  let emit_alternative (_, symbol) = emit (f symbol); emit " | " in
    emit "[";
    Hashtbl.iter
      (fun _ sym_list -> List.iter emit_alternative !sym_list)
      (fst block);
    emit "]";
    Buffer.contents buf


       (* render a local table as a string
        *
        * Every (shape, symbol) alternative is printed with f and
        * followed by " | "; the whole is enclosed in square
        * brackets.  Entries appear in the order of Hashtbl.iter.
        *
        * val string_of_local : ('space, 'shape, 'symbol) local ->
        *      ('shape -> 'symbol -> string) -> string
        *)
let string_of_local block f =
  (* Collect the pieces in a Buffer, as dump_local does: appending
   * with ^ to an ever growing string copies it again and again. *)
  let res = Buffer.create 20 in
  let o s = Buffer.add_string res s in
    o "[";
    Hashtbl.iter 
      (fun _ sym_list -> 
         List.iter (fun (shape,symbol) ->  o (f shape symbol); o " | ") 
           !sym_list) (fst block);
    o "]";
    Buffer.contents res


(*** Local Variables: ***)
(*** version-control: t ***)
(*** kept-new-versions: 5 ***)
(*** delete-old-versions: t ***)
(*** time-stamp-line-limit: 30 ***)
(*** End: ***)

