from typing import Any, Dict

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from accelerate import init_empty_weights, load_checkpoint_and_dispatch


class EndpointHandler:
    def __init__(self, model_dir: str, **kw):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)

        # Build the model skeleton on the meta device so no weights are
        # materialized yet. from_pretrained() should not be called inside
        # init_empty_weights(); instantiate from the config instead.
        config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
        with init_empty_weights():
            base = AutoModelForCausalLM.from_config(
                config, torch_dtype=torch.float16, trust_remote_code=True
            )

        # Stream the checkpoint shards onto the devices chosen by
        # device_map="auto", filling in the empty weights. Some architectures
        # also need no_split_module_classes=[...] so that a residual block is
        # never split across two devices.
        self.model = load_checkpoint_and_dispatch(
            base, checkpoint=model_dir, device_map="auto", dtype=torch.float16
        ).eval()

        # Inputs must live on whichever device holds the input embeddings;
        # with device_map="auto" this is not necessarily cuda:0. Guard the
        # set_device call in case the embeddings were placed on CPU.
        self.embed_device = self.model.get_input_embeddings().weight.device
        if self.embed_device.type == "cuda":
            torch.cuda.set_device(self.embed_device)
        print(">>> embedding on", self.embed_device)

        # Many causal-LM tokenizers define no pad token; reusing the EOS id
        # silences the corresponding warning from generate().
        self.gen_kwargs = dict(
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )

    def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
        # Inference Endpoints deliver the request payload as {"inputs": ...}.
        prompt = data["inputs"]

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.embed_device)
        with torch.inference_mode():
            out_ids = self.model.generate(**inputs, **self.gen_kwargs)

        # out_ids[0] still contains the prompt tokens, so the decoded string
        # is the prompt followed by the completion.
        return {"generated_text": self.tokenizer.decode(out_ids[0], skip_special_tokens=True)}
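
# A minimal local smoke test, assuming a checkpoint directory at the
# placeholder path "./my-model" (not part of the original handler). The
# platform normally constructs the handler once, then calls it per request.
if __name__ == "__main__":
    handler = EndpointHandler("./my-model")
    result = handler({"inputs": "Explain beam search in one sentence."})
    print(result["generated_text"])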