Mihaiii committed
Commit 7a641cc · verified · 1 Parent(s): d0e140f
Update app.py

Files changed (1)
  1. app.py +4 -7
app.py CHANGED
@@ -15,13 +15,10 @@ model_name = "unsloth/Llama-3.2-1B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 # Load two instances of the model on CUDA for parallel inference
-model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
-
-model2 = AutoModelForCausalLM.from_pretrained(model_name)
-device = torch.device('cuda')
-
-strategy = CreativeWritingStrategy()
-provider = TransformersProvider(model2, tokenizer, device)
+model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
+
+provider = TransformersProvider(model, tokenizer, device)
+strategy = CreativeWritingStrategy(provider)
 creative_sampler = BacktrackSampler(strategy, provider)
 
 # Helper function to create message array for the chat template
@@ -52,7 +49,7 @@ def generate_responses(prompt, history):
         return tokenizer.decode(generated_list, skip_special_tokens=True)
 
     custom_output = asyncio.run(custom_sampler_task())
-    standard_output = model1.generate(inputs, max_length=2048, temperature=1)
+    standard_output = model.generate(inputs, max_length=2048, temperature=1)
     # Decode standard output and remove the prompt from the generated response
     standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
 
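
For context, a minimal sketch of how the pieces touched by this commit fit together after the change: a single model instance on CUDA now backs both the BacktrackSampler provider and the plain transformers baseline called in the second hunk. Only the lines copied from the diff are from the actual app.py; the import paths for the backtrack_sampler classes, the device definition, and the baseline_response helper are assumptions for illustration.

# Sketch of the post-commit setup; the backtrack_sampler import locations
# below are assumed, not taken from app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy  # assumed import path
from backtrack_sampler.provider.transformers_provider import TransformersProvider  # assumed import path

model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device("cuda")  # assumed to be defined earlier in app.py

tokenizer = AutoTokenizer.from_pretrained(model_name)

# One shared model instance on CUDA replaces the former model1/model2 pair.
model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")

provider = TransformersProvider(model, tokenizer, device)
strategy = CreativeWritingStrategy(provider)
creative_sampler = BacktrackSampler(strategy, provider)

# Baseline path from the second hunk: generate with the same shared model and
# slice the prompt tokens off the output before decoding.
def baseline_response(inputs: torch.Tensor) -> str:
    standard_output = model.generate(inputs, max_length=2048, temperature=1)
    return tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)

The apparent point of the change is also visible here: the second copy of the weights (the old model2, which stayed on CPU) is gone, so the app keeps one set of weights on the GPU and both the custom-sampler output and the standard output are produced from it.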