Skip to content

Commit eeda42e

Browse files
Issue 88 and 120 - Fix
1 parent ea7efb1 commit eeda42e

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

precise/network_runner.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,14 +68,18 @@ def run(self, inp: np.ndarray) -> float:
6868

6969
class KerasRunner(Runner):
7070
def __init__(self, model_name: str):
71-
import os
72-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
7371
import tensorflow as tf
72+
# ISSUE 88 - Following 3 lines added to resolve issue 88 - JM 2020-02-04 per liny90626
73+
from tensorflow.python.keras.backend import set_session # ISSUE 88
74+
self.sess = tf.Session() # ISSUE 88
75+
set_session(self.sess) # ISSUE 88
7476
self.model = load_precise_model(model_name)
7577
self.graph = tf.get_default_graph()
7678

7779
def predict(self, inputs: np.ndarray):
80+
from tensorflow.python.keras.backend import set_session # ISSUE 88
7881
with self.graph.as_default():
82+
set_session(self.sess) # ISSUE 88
7983
return self.model.predict(inputs)
8084

8185
def run(self, inp: np.ndarray) -> float:

0 commit comments

Comments
 (0)